Compare commits

...

133 Commits

Author SHA1 Message Date
Richard Kuo (Danswer)
bdaaebe955 use re.search instead of re.match (which searches from start of string only) 2024-08-07 20:55:18 -07:00
pablodanswer
9eb48ca2c3 account for empty links + fix quote processing 2024-08-07 20:55:18 -07:00
rkuo-danswer
509fa3a994 add postgres configuration (#2076) 2024-08-08 00:13:59 +00:00
pablodanswer
5097c7f284 Handle saved search docs in eval flow (#2075) 2024-08-07 16:18:34 -07:00
pablodanswer
c4e1c62c00 Admin UX updates (#2057) 2024-08-07 14:55:16 -07:00
pablodanswer
eab82782ca Add proper delay for assistant switching (#2070)
* add proper delay for assistant switching

* persist input if possible
2024-08-07 14:46:15 -07:00
pablodanswer
53d976234a proper new chat button redirects (#2074) 2024-08-07 14:44:42 -07:00
pablodanswer
44d8e34b5a Improve seeding (includes all enterprise features) (#2065) 2024-08-07 10:44:33 -07:00
pablodanswer
d2e16a599d Improve shared chat page (#2066)
* improve look of shared chat page

* remove log

* cleaner display

* add initializing loader to shared chat page

* updated danswer loaders (for prism)

* remove default share
2024-08-07 16:13:55 +00:00
pablodanswer
291e6c4198 somewhat clearer API errors (#2064) 2024-08-07 03:04:26 +00:00
Chris Weaver
bb7e1d6e55 Add integration tests for document set syncing (#1904) 2024-08-06 18:00:19 -07:00
rkuo-danswer
fcc4c30ead don't skip the start of the json answer value (#2067) 2024-08-06 23:59:13 +00:00
pablodanswer
f20984ea1d Don't persist error perennially (#2061)
* don't persist error perennially

* proper functionality

* remove logs

* remove another log

* add comments for clarity + reverse conditional

* add comment back

* remove comment
2024-08-06 23:09:25 +00:00
pablodanswer
e0f0cfd92e Ensure relevance functions for selected docs (#2063)
* ensure relevance functions for selected docs

* remove logs

* remove log
2024-08-06 21:06:44 +00:00
pablodanswer
57aec7d02a doc sidebar width fix 2024-08-06 13:48:47 -07:00
pablodanswer
6350219143 Add proper default temperature + overrides (#2059)
* add proper default temperature + overrides

* remove unclear commment

* ammend defaults + include internet serach
2024-08-06 19:57:14 +00:00
pablodanswer
3bc2cf9946 update tool display bubbles to have cursor-dfeault 2024-08-06 12:49:42 -07:00
pablodanswer
7f7452dc98 Whitelabelling consistency (#2058)
* add white labelling to admin sidebar

* even more consistency
2024-08-06 19:45:38 +00:00
pablodanswer
dc2a50034d Clean chat banner (#2056)
* fully functional

* formatting

* ensure consistency with large logos

* ensure mobile support
2024-08-06 19:44:14 +00:00
pablodanswer
ab564a9ec8 Add cleaner loading / streaming for image loading (#2055)
* add image loading

* clean

* add loading skeleton

* clean up

* clearer comments
2024-08-06 19:28:48 +00:00
rkuo-danswer
cc3856ef6d enforce index attempt deduping on secondary indexing. (#2054)
* enforce index attempt deduping on secondary indexing.

* black fix

* typo fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-08-06 17:45:16 +00:00
Yuhong Sun
a8a4ad9546 Chunk Filter Metadata Format (#2053) 2024-08-05 15:12:36 -07:00
pablodanswer
5bfdecacad fix assistant drag transform effect (#2052) 2024-08-05 14:53:38 -07:00
pablodanswer
0bde66a888 remove "quotes" section (#2049) 2024-08-05 18:51:43 +00:00
pablodanswer
5825d01d53 Better assistant interactions + UI (#2029)
* add assistnat re-ordering, selections, etc.

* squash

* remove unnecessary comment

* squash

* adapt dragging for all IDs + smoother animation + consistency

* fix minor typing issue

* fix minor typing issue

* remove logs
2024-08-05 18:22:57 +00:00
pablodanswer
cd22cca4e8 remove non-EE public connector options 2024-08-05 11:14:20 -07:00
pablodanswer
a3ea217f40 ensure consistency of answers + update llm relevance prompting (#2045) 2024-08-05 08:27:15 -07:00
pablodanswer
66e4dded91 Add properly random icons to assistant creation page (#2044) 2024-08-04 23:30:17 -07:00
pablodanswer
6d67d472cd Add answers to search (#2020) 2024-08-04 23:02:55 -07:00
Weves
76b7792e69 Harden embedding calls 2024-08-04 15:11:45 -07:00
Chris Weaver
9d7100a287 Fix secondary index attempts showing up as the primary index status + scheduling while in-progress (#2039) 2024-08-04 13:29:44 -07:00
pablodanswer
876feecd6f Fix code pasting formatting (#2033)
* fix pasting formatting

* add back small comments
2024-08-04 09:56:48 -07:00
pablodanswer
0261d689dc Various Admin Page + User Flow Improvements (#1987) 2024-08-03 18:09:46 -07:00
pablodanswer
aa4a00cbc2 fix minor html error (#2034) 2024-08-03 12:40:07 -07:00
Nathan Schwerdfeger
52c505c210 Remove partially implemented reply cancellation (#2031)
* fix: remove partially implemented response cancellation

* feat: notify user when unsupported chat cancellation is requested

* fix: correct ChatInputBar streaming detection logic
2024-08-03 18:12:04 +00:00
pablodanswer
ed455394fc detect foreign key composition sessions (#2024) 2024-08-02 17:26:57 +00:00
hagen-danswer
57cc53ab94 Added content tags to zendesk connector (#2017) 2024-08-02 10:09:53 -07:00
rkuo-danswer
6a61331cba Feature/log despam (#2022)
* move a lot of log spam to debug level. Consolidate some info level logging

* reformat more indexing logging
2024-08-02 15:28:53 +00:00
Weves
51731ad0dd Fix issue where large docs/batches break openai embedding 2024-08-02 01:07:09 -07:00
rkuo-danswer
f280586e68 pass function to Process correctly instead of running it inline (#2018)
* pass function to Process correctly instead of running it inline

* mypy fixes and pass back return result (even tho we don't use it right now)
2024-08-02 00:06:35 +00:00
hagen-danswer
e31d6be4ce Switched build to use a larger runner (#2019) 2024-08-01 14:29:45 -07:00
hagen-danswer
e6a92aa936 support confluence single page only indexing (#2008)
* added index recursively checkbox

* mypy fixes

* added migration to not break existing connectors
2024-08-01 20:32:46 +00:00
pablodanswer
a54ea9f9fa Fix cartesian issue with index attempts (#2015) 2024-08-01 10:25:25 -07:00
Yuhong Sun
73a92c046d Fix chunker (#2014) 2024-08-01 10:18:02 -07:00
pablodanswer
459bd46846 Add Prompt library (#1990) 2024-08-01 08:40:35 -07:00
Chris Weaver
445f7e70ba Fix image generation (#2009) 2024-08-01 00:27:02 -07:00
Yuhong Sun
ca893f9918 Rerank Handle Null (#2010) 2024-07-31 22:59:02 -07:00
hagen-danswer
1be1959d80 Changed default local model to nomic (#1943) 2024-07-31 18:54:02 -07:00
Chris Weaver
1654378850 Fix user dropdown font (#2007) 2024-08-01 00:29:14 +00:00
Chris Weaver
d6d391d244 Fix not_applicable (#2003) 2024-07-31 21:30:07 +00:00
rkuo-danswer
7c283b090d Feature/postgres connection names (#1998)
* avoid reindexing secondary indexes after they succeed

* use postgres application names to facilitate connection debugging

* centralize all postgres application_name constants in the constants file

* missed a couple of files

* mypy fixes

* update dev background script
2024-07-31 20:36:30 +00:00
pablodanswer
40226678af Add proper default values for assistant editing / creation (#2001) 2024-07-31 13:34:42 -07:00
rkuo-danswer
288e6fa606 Bugfix/pg connections (#2002)
* increase max_connections to 150 in all docker files

* lower celery worker concurrency to 6
2024-07-31 19:49:20 +00:00
hagen-danswer
5307d38472 Fixed tokenizer logic (#1986) 2024-07-31 09:59:45 -07:00
Yuhong Sun
d619602a6f Skip shortcut docs (#1999) 2024-07-31 09:51:01 -07:00
Yuhong Sun
348a2176f0 Fix Dropped Documents (#1997) 2024-07-31 09:33:36 -07:00
pablodanswer
89b6da36a6 process files with null title (#1989) 2024-07-31 08:18:50 -07:00
Yuhong Sun
036d5c737e No Null Embeddings (#1982) 2024-07-30 19:54:49 -07:00
pablodanswer
60a87d9472 Add back modals on chat page (#1983) 2024-07-30 17:42:59 -07:00
pablodanswer
eb9bb56829 Add initial mobile support (#1962) 2024-07-30 17:13:50 -07:00
hagen-danswer
d151082871 Moved warmup_encoders into scope (#1978) 2024-07-30 16:37:32 +00:00
pablodanswer
e4b1f5b963 fix index attempt migration where no credential ID 2024-07-30 08:57:57 -07:00
hagen-danswer
3938a053aa Rework tokenizer (#1957) 2024-07-29 23:01:49 -07:00
pablodanswer
7932e764d6 Make chat page layout cleaner + fix updating assistant images (#1973)
* ux updates for clarity
- [x] 'folders' -> 'chat folders'
- [x] sidebar to bottom left and smaller
- [x] Sidebar -> smaller logo
- [x] Align things properly
- [x] Expliti Pin: immediate + "Pin / Unpin"
- [x] Logo size smaller
- [x] Align things properly
- [x] Optionally fix gradient in sidebar
- [x] Upload logo to existing assistants

* remove unneeded logs

* run pretty

* actually run pretty!

* fix web file type

* fix very minor typo

* clean type for buildPersonaAPIBody

* fix span formatting

* HUGE ui change
2024-07-30 03:44:35 +00:00
Chris Weaver
fb6695a983 Fix flow where oidc_expiry is different from token expiry (#1974) 2024-07-30 03:17:08 +00:00
rkuo-danswer
015f415b71 avoid reindexing secondary indexes after they succeed (#1971) 2024-07-30 03:12:58 +00:00
rkuo-danswer
96b582070b authorized users and groups only have read access (#1960)
* authorized users and groups only have read access

* slightly better variable naming
2024-07-29 19:53:42 +00:00
rkuo-danswer
4a0a927a64 fix removed parameter in MediaWikiConnector (#1970) 2024-07-29 18:47:30 +00:00
hagen-danswer
ea9a9cb553 Fix typing for previous message 2024-07-29 10:01:38 -07:00
pablodanswer
38af12ab97 remove unnecessary index drop (#1968) 2024-07-29 09:51:53 -07:00
hagen-danswer
1b3154188d Fixed default indexing frequency (#1965)
* Fixed default indexing frequency

* fixed more defaults
2024-07-29 08:14:49 -07:00
Weves
1f321826ad Bigger images 2024-07-28 23:47:06 -07:00
Weves
cbfbe4e5d8 Fix image generation follow up q 2024-07-28 23:47:06 -07:00
pablodanswer
3aa0e0124b Add new admin page (#1947)
* add admin page

* credential + typing fix

* rebase fix

* on add, cleaner buttons

* functional G + Ddrive

* organized auth sections

* update types and remove logs

* ccs -> connectors

* validated formik

* update styling + connector-handling logic

* udpate colors

* separate out hooks + util functions

* update to adhere to rest standards

* remove "todos"

* rebase

* copy + formatting + sidebar

* update statuses + configuration possibilities

* update interfaces to be clearer

* update indexing status page

* formatting

* address backend security + comments

* update font

* fix form routing

* fix hydration error

* add statuses, fix bugs, etc. (squash)

* fix color (squash)

* squash

* add functionality to sidebar

* disblae buttons if deleting

* add color

* minor copy + formatting updates
- on modify credential, close
- update copy for deletion of connectors

* fix build error

* copy

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-28 20:57:43 -07:00
Yuhong Sun
f2f60c9cc0 Fix EE Import backoff Logic (#1959) 2024-07-27 11:06:11 -07:00
Emerson Gomes
6c32821ad4 Allow removal of max_output_tokens by setting GEN_AI_MAX_OUTPUT_TOKENS=0 (#1958)
Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-07-27 09:07:29 -07:00
Weves
d839595330 Add query override 2024-07-26 17:40:21 -07:00
Yuhong Sun
e422f96dff Pull Request Template (#1956) 2024-07-26 17:34:05 -07:00
Weves
d28f460330 Fix black 2024-07-26 16:43:15 -07:00
Eugene Astroner
8e441d975d Issue fix 2024-07-26 16:40:31 -07:00
pablodanswer
5c78af1f07 Deduplicate model names (#1950) 2024-07-26 16:30:49 -07:00
rkuo-danswer
e325e063ed Bugfix/persona access (#1951)
* also allow access to a persona if the user is in the list of authorized users or groups

* add comment on potential performance improvements

* work around for mypy typing
2024-07-26 22:05:57 +00:00
pablodanswer
c81b45300b Configurable models + updated assistants bar (#1942) 2024-07-26 11:00:49 -07:00
pablodanswer
26a1e963d1 Update personas.yaml (#1948) 2024-07-25 20:35:49 -07:00
pablodanswer
2a983263c7 Small update- Danswer update icons as well (#1945) 2024-07-25 20:31:41 -07:00
Yuhong Sun
2a37c95a5e Types for Migrations (#1944) 2024-07-25 18:18:48 -07:00
pablodanswer
c277a74f82 Add icons to assistants! (#1930) 2024-07-25 18:02:39 -07:00
rkuo-danswer
e4b31cd0d9 allow setting secondary worker count via environment variable. default to primary worker count if unset. (#1941) 2024-07-25 20:25:43 +00:00
hagen-danswer
a40d2a1e2e Change the way we get sqlalchemy session (#1940)
* changed default fast model to gpt-4o-mini

* Changed the way we get the sqlalchemy session
2024-07-25 18:36:14 +00:00
hagen-danswer
c9fb99d719 changed default fast model to gpt-4o-mini (#1939) 2024-07-25 10:50:02 -07:00
hagen-danswer
a4d71e08aa Added check for unknown tool names (#1924)
* answer.py

* Let it continue if broken
2024-07-25 00:19:08 +00:00
rkuo-danswer
546bfbd24b autoscale with pool=thread crashes celery. remove and use concurrency… (#1929)
* autoscale with pool=thread crashes celery. remove and use concurrency instead (to be improved later)

* update dev background script as well
2024-07-25 00:15:27 +00:00
hagen-danswer
27824d6cc6 Fixed login issue (#1920)
* included check for existing emails

* cleaned up logic
2024-07-25 00:03:29 +00:00
Weves
9d5c4ad634 Small fix for non tool calling LLMs 2024-07-24 15:41:43 -07:00
Shukant Pal
9b32003816 Handle SSL error tracebacks in site indexing connector (#1911)
My website (https://shukantpal.com) uses Let's Encrypt certificates, which aren't accepted by the Python urllib certificate verifier for some reason. My website is set up correctly otherwise (https://www.sslshopper.com/ssl-checker.html#hostname=www.shukantpal.com)

This change adds a fix so the correct traceback is shown in Danswer, instead of a generic "unable to connect, check your Internet connection".
2024-07-24 22:36:29 +00:00
pablodanswer
8bc4123ed7 add modern health check banner + expiration tracking (#1730)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-24 15:34:22 -07:00
pablodanswer
d58aaf7a59 add href 2024-07-24 14:33:56 -07:00
pablodanswer
a0056a1b3c add files (images) (#1926) 2024-07-24 21:26:01 +00:00
pablodanswer
d2584c773a slightly clearer description of model settings in assistants creation tab (#1925) 2024-07-24 21:25:30 +00:00
pablodanswer
807bef8ada Add environment variable for defaulted sidebar toggling (#1923)
* add env variable for defaulted sidebar toggling

* formatting

* update naming
2024-07-24 21:23:37 +00:00
rkuo-danswer
5afddacbb2 order list of new attempts from oldest to newest to prevent connector starvation (#1918) 2024-07-24 21:02:20 +00:00
hagen-danswer
4fb6a88f1e Quick fix (#1919) 2024-07-24 11:56:14 -07:00
rkuo-danswer
7057be6a88 Bugfix/indexing progress (#1916)
* mark in progress should always be committed

* no_commit version of mark_attempt is not needed
2024-07-24 11:39:44 -07:00
Yuhong Sun
91be8e7bfb Skip Null Docs (#1917) 2024-07-24 11:31:33 -07:00
Yuhong Sun
9651ea828b Handling Metadata by Vector and Keyword (#1909) 2024-07-24 11:05:56 -07:00
rkuo-danswer
6ee74bd0d1 fix pointers to various background tasks and scripts (#1914) 2024-07-24 10:12:51 -07:00
pablodanswer
48a0d29a5c Fix empty / reverted embeddings (#1910) 2024-07-23 22:41:31 -07:00
hagen-danswer
6ff8e6c0ea Improve eval pipeline qol (#1908) 2024-07-23 17:16:34 -07:00
Yuhong Sun
2470c68506 Don't rephrase first chat query (#1907) 2024-07-23 16:20:11 -07:00
hagen-danswer
866bc803b1 Implemented LLM disabling for api call (#1905) 2024-07-23 16:12:51 -07:00
pablodanswer
9c6084bd0d Embeddings- Clean up modal + "Important" call out (#1903) 2024-07-22 21:29:22 -07:00
hagen-danswer
a0b46c60c6 Switched eval api target back to oneshotqa (#1902) 2024-07-22 20:55:18 -07:00
pablodanswer
4029233df0 hide incomplete sources for non-admins (#1901) 2024-07-22 13:40:11 -07:00
hagen-danswer
6c88c0156c Added file upload retry logic (#1889) 2024-07-22 13:13:22 -07:00
pablodanswer
33332d08f2 fix citation title (#1900)
* fix citation title

* remove title function
2024-07-22 17:37:04 +00:00
hagen-danswer
17005fb705 switched default pruning behavior and removed some logging (#1898) 2024-07-22 17:36:26 +00:00
hagen-danswer
48a7fe80b1 Committed LLM updates to db (#1899) 2024-07-22 10:30:24 -07:00
pablodanswer
1276732409 Misc bug fixes (#1895) 2024-07-22 10:22:43 -07:00
Weves
f91b92a898 Make is_public default true for LLMProvider 2024-07-21 22:22:37 -07:00
Weves
6222f533be Update force delete script to handle user groups 2024-07-21 22:22:37 -07:00
hagen-danswer
1b49d17239 Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
2024-07-22 04:31:44 +00:00
Yuhong Sun
2f5f19642e Double Check Max Tokens for Indexing (#1893) 2024-07-21 21:12:39 -07:00
Yuhong Sun
6db4634871 Token Truncation (#1892) 2024-07-21 16:26:32 -07:00
Yuhong Sun
5cfed45cef Handle Empty Titles (#1891) 2024-07-21 14:59:23 -07:00
Weves
581ffde35a Fix jira connector failures for server deployments 2024-07-21 14:44:25 -07:00
pablodanswer
6313e6d91d Remove visit api when unneded (#1885)
* quick fix to test on ec2

* quick cleanup

* modify a name

* address full doc as well

* additional timing info + handling

* clean up

* squash

* Print only
2024-07-21 20:57:24 +00:00
Weves
c09c94bf32 Fix assistant swap 2024-07-21 13:57:36 -07:00
Yuhong Sun
0e8ba111c8 Model Touchups (#1887) 2024-07-21 12:31:00 -07:00
Yuhong Sun
2ba24b1734 Reenable Search Pipeline (#1886) 2024-07-21 10:33:29 -07:00
Yuhong Sun
44820b4909 k 2024-07-21 10:27:57 -07:00
hagen-danswer
eb3e7610fc Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exceptin printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-20 22:10:18 -07:00
pablodanswer
7fbbb174bb minor fixes (#1882)
- Assistants tab size
- Fixed logo -> absolute
2024-07-20 21:02:57 -07:00
pablodanswer
3854ca11af add newlines for message content 2024-07-20 18:57:29 -07:00
406 changed files with 16904 additions and 16079 deletions

25
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,25 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge

View File

@@ -7,7 +7,8 @@ on:
jobs:
build-and-push:
runs-on: ubuntu-latest
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code

View File

@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
DISABLE_LLM_DOC_RELEVANCE=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)

View File

@@ -68,7 +68,9 @@ RUN apt-get update && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \

View File

@@ -18,14 +18,17 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('danswer/intent-model', cache_folder='/root/.cache/temp_huggingface/hub/'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
from transformers import TFDistilBertForSequenceClassification; \
TFDistilBertForSequenceClassification.from_pretrained('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
from huggingface_hub import snapshot_download; \
AutoTokenizer.from_pretrained('danswer/intent-model'); \
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download('danswer/intent-model'); \
snapshot_download('intfloat/e5-base-v2'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
snapshot_download('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
WORKDIR /app

View File

@@ -17,15 +17,11 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_session", "current_alternate_model")
# ### end Alembic commands ###

View File

@@ -0,0 +1,26 @@
"""add_indexing_start_to_connector
Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
)
def downgrade() -> None:
op.drop_column("connector", "indexing_start")

View File

@@ -79,7 +79,7 @@ def downgrade() -> None:
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],

View File

@@ -160,12 +160,28 @@ def downgrade() -> None:
nullable=False,
),
)
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
constraint["name"] == "fk_index_attempt_credential_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
if any(
constraint["name"] == "fk_index_attempt_connector_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")

View File

@@ -0,0 +1,70 @@
"""Add icon_color and icon_shape to Persona
Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562
"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None
colorOptions = [
"#FF6FBF",
"#6FB1FF",
"#B76FFF",
"#FFB56F",
"#6FFF8D",
"#FF6F6F",
"#6FFFFF",
]
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
center_fill = random.choice(center_squares)
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
random.shuffle(remaining_squares)
for i in range(10 - bin(center_fill).count("1")):
center_fill |= 1 << remaining_squares[i]
return center_fill
def upgrade() -> None:
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
persona = table(
"persona",
column("id", sa.Integer),
column("icon_color", sa.String),
column("icon_shape", sa.Integer),
)
conn = op.get_bind()
personas = conn.execute(select(persona.c.id))
for persona_id in personas:
random_color = random.choice(colorOptions)
random_shape = generate_random_shape()
conn.execute(
persona.update()
.where(persona.c.id == persona_id[0])
.values(icon_color=random_color, icon_shape=random_shape)
)
def downgrade() -> None:
op.drop_column("persona", "icon_shape")
op.drop_column("persona", "uploaded_image_id")
op.drop_column("persona", "icon_color")

View File

@@ -18,7 +18,6 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
@@ -29,10 +28,8 @@ def upgrade() -> None:
["alternate_assistant_id"],
["id"],
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -0,0 +1,42 @@
"""Rename index_origin to index_recursively
Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_recursively}',
'true'::jsonb
) - 'index_origin'
WHERE connector_specific_config ? 'index_origin'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_origin}',
connector_specific_config->'index_recursively'
) - 'index_recursively'
WHERE connector_specific_config ? 'index_recursively'
"""
)

View File

@@ -0,0 +1,49 @@
"""Add display_model_names to llm_provider
Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None
default_models_by_provider = {
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
"bedrock": [
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"mistral.mistral-large-2402-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
],
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
)
connection = op.get_bind()
for provider, models in default_models_by_provider.items():
connection.execute(
sa.text(
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
),
{"models": models, "provider": provider},
)
def downgrade() -> None:
op.drop_column("llm_provider", "display_model_names")

View File

@@ -0,0 +1,72 @@
"""Add type to credentials
Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new 'source' column to the 'credential' table
op.add_column(
"credential",
sa.Column(
"source",
sa.String(length=100), # Use String instead of Enum
nullable=True, # Initially allow NULL values
),
)
op.add_column(
"credential",
sa.Column(
"name",
sa.String(),
nullable=True,
),
)
# Create a temporary table that maps each credential to a single connector source.
# This is needed because a credential can be associated with multiple connectors,
# but we want to assign a single source to each credential.
# We use DISTINCT ON to ensure we only get one row per credential_id.
op.execute(
"""
CREATE TEMPORARY TABLE temp_connector_credential AS
SELECT DISTINCT ON (cc.credential_id)
cc.credential_id,
c.source AS connector_source
FROM connector_credential_pair cc
JOIN connector c ON cc.connector_id = c.id
"""
)
# Update the 'source' column in the 'credential' table
op.execute(
"""
UPDATE credential cred
SET source = COALESCE(
(SELECT connector_source
FROM temp_connector_credential temp
WHERE cred.id = temp.credential_id),
'NOT_APPLICABLE'
)
"""
)
# If no exception was raised, alter the column
op.alter_column("credential", "source", nullable=True) # TODO modify
# # ### end Alembic commands ###
def downgrade() -> None:
op.drop_column("credential", "source")
op.drop_column("credential", "name")

View File

@@ -28,5 +28,9 @@ def upgrade() -> None:
def downgrade() -> None:
# This wasn't really required by the code either, no good reason to make it unique again
pass
op.create_unique_constraint(
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
)
op.alter_column(
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
)

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -0,0 +1,103 @@
"""associate index attempts with ccpair
Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new connector_credential_pair_id column
op.add_column(
"index_attempt",
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
)
# Create a foreign key constraint to the connector_credential_pair table
op.create_foreign_key(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
"connector_credential_pair",
["connector_credential_pair_id"],
["id"],
)
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id =
CASE
WHEN ia.credential_id IS NULL THEN
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
LIMIT 1)
ELSE
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
AND credential_id = ia.credential_id)
END
WHERE ia.connector_id IS NOT NULL
"""
)
# Make the new connector_credential_pair_id column non-nullable
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
# Drop the old connector_id and credential_id columns
op.drop_column("index_attempt", "connector_id")
op.drop_column("index_attempt", "credential_id")
# Update the index to use connector_credential_pair_id
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_credential_pair_id", "time_created"],
)
def downgrade() -> None:
# Add back the old connector_id and credential_id columns
op.add_column(
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
)
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
FROM connector_credential_pair ccp
WHERE ia.connector_credential_pair_id = ccp.id
"""
)
# Make the old connector_id and credential_id columns non-nullable
op.alter_column("index_attempt", "connector_id", nullable=False)
op.alter_column("index_attempt", "credential_id", nullable=False)
# Drop the new connector_credential_pair_id column
op.drop_constraint(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
type_="foreignkey",
)
op.drop_column("index_attempt", "connector_credential_pair_id")
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_id", "credential_id", "time_created"],
)

View File

@@ -0,0 +1,26 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
)
def downgrade() -> None:
op.drop_column("user", "oidc_expiry")

View File

@@ -16,7 +16,6 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -29,11 +28,9 @@ def upgrade() -> None:
),
nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -46,4 +43,3 @@ def downgrade() -> None:
),
nullable=False,
)
# ### end Alembic commands ###

View File

@@ -19,6 +19,9 @@ depends_on: None = None
def upgrade() -> None:
op.drop_table("deletion_attempt")
# Remove the DeletionStatus enum
op.execute("DROP TYPE IF EXISTS deletionstatus;")
def downgrade() -> None:
op.create_table(

View File

@@ -136,4 +136,4 @@ def downgrade() -> None:
)
op.drop_column("index_attempt", "embedding_model_id")
op.drop_table("embedding_model")
op.execute("DROP TYPE indexmodelstatus;")
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")

View File

@@ -0,0 +1,58 @@
"""Added input prompts
Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -1,6 +1,8 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
@@ -50,8 +52,10 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -92,12 +96,18 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_in_whitelist(email: str) -> None:
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
def verify_email_domain(email: str) -> None:
if VALID_EMAIL_DOMAINS:
if email.count("@") != 1:
@@ -147,7 +157,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
verify_email_in_whitelist(user_create.email)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
@@ -173,7 +183,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
return await super().oauth_callback( # type: ignore
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
@@ -185,6 +195,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: google oauth expires after 1hr. We don't want to force the user to
# re-authenticate that frequently, so for now we'll just ignore this for
# google oauth users
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -227,10 +245,12 @@ cookie_transport = CookieTransport(
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",
@@ -327,6 +347,12 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
return user
@@ -345,4 +371,5 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not an admin.",
)
return user

View File

@@ -14,6 +14,7 @@ from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -38,7 +39,9 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
connection_string = build_connection_string(db_api=SYNC_DB_API)
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@@ -100,7 +103,7 @@ def cleanup_connector_credential_pair_task(
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:

View File

@@ -41,6 +41,12 @@ def _initializer(
return func(*args, **kwargs)
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, args, kwargs)
@dataclass
class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
@@ -113,7 +119,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_initializer(func=func, args=args), daemon=True)
process = Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -20,7 +20,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -49,19 +49,19 @@ def _get_document_generator(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
task = attempt.connector.input_type
task = attempt.connector_credential_pair.connector.input_type
try:
runnable_connector = instantiate_connector(
attempt.connector.source,
attempt.connector_credential_pair.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
disable_connector(attempt.connector_credential_pair.connector.id, db_session)
raise e
if task == InputType.LOAD_STATE:
@@ -70,7 +70,10 @@ def _get_document_generator(
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
if (
attempt.connector_credential_pair.connector_id is None
or attempt.connector_credential_pair.connector_id is None
):
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
@@ -127,16 +130,21 @@ def _run_indexing(
db_session=db_session,
)
db_connector = index_attempt.connector
db_credential = index_attempt.credential
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
last_successful_index_time = (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
db_connector.indexing_start.timestamp()
if index_attempt.from_beginning and db_connector.indexing_start is not None
else (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
)
@@ -189,7 +197,7 @@ def _run_indexing(
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
document_batch=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
@@ -250,8 +258,8 @@ def _run_indexing(
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
net_docs=net_doc_change,
)
raise e
@@ -269,11 +277,9 @@ def _run_indexing(
run_dt=run_end_dt,
)
elapsed_time = time.time() - start_time
logger.info(
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
f"Connector succeeded: docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
@@ -299,9 +305,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()
mark_attempt_in_progress(attempt, db_session)
return attempt
@@ -324,17 +328,19 @@ def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -16,15 +16,19 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import Connector
@@ -33,7 +37,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -66,28 +70,46 @@ def _should_create_new_indexing(
return False
# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
return False
if model.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is disabled, don't index
# NOTE: during an embedding model switch over, we ignore this
# and index the disabled connectors as well (which is why this if
# statement is below the first condition above)
if connector.disabled:
# If the connector is disabled or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if connector.disabled or connector.id == 0:
return False
if connector.refresh_freq is None:
return False
if not last_index:
return True
# Only one scheduled job per connector at a time
# Can schedule another one if the current one is already running however
# Because the currently running one will not be until the latest time
# Note, this last index is for the given embedding model
if last_index.status == IndexingStatus.NOT_STARTED:
if connector.refresh_freq is None:
return False
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
return False
current_db_time = get_db_current_time(db_session)
@@ -111,8 +133,8 @@ def _mark_run_failed(
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
@@ -131,7 +153,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int | None, int]] = set()
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@@ -144,8 +166,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
continue
ongoing.add(
(
attempt.connector_id,
attempt.credential_id,
attempt.connector_credential_pair_id,
attempt.embedding_model_id,
)
)
@@ -155,31 +176,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
all_connectors = fetch_connectors(db_session)
for connector in all_connectors:
for association in connector.credentials:
for model in embedding_models:
credential = association.credential
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for model in embedding_models:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, model.id) in ongoing:
continue
# Check if there is an ongoing indexing attempt for this connector + credential pair
if (connector.id, credential.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, model.id, db_session
)
if not _should_create_new_indexing(
connector=cc_pair.connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
last_attempt = get_last_attempt(
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(
connector.id, credential.id, model.id, db_session
)
create_index_attempt(cc_pair.id, model.id, db_session)
def cleanup_indexing_jobs(
@@ -271,24 +287,28 @@ def kickoff_indexing_jobs(
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector is None:
if attempt.connector_credential_pair.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
@@ -297,7 +317,7 @@ def kickoff_indexing_jobs(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.credential is None:
if attempt.connector_credential_pair.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
@@ -323,35 +343,52 @@ def kickoff_indexing_jobs(
)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
logger.info(
f"Kicked off {secondary_str}"
f"indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing dispatched{secondary_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.debug("Running a first inference to warm up embedding model")
warm_up_encoders(
embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -366,7 +403,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_workers,
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
@@ -376,18 +413,18 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Running update, current UTC time: {start_time_utc}")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.info(
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
@@ -411,8 +448,9 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
logger.info("Starting Indexing Loop")
logger.info("Starting indexing service")
update_loop()

View File

@@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
@@ -43,6 +44,7 @@ def create_chat_chain(
user_id=None,
db_session=db_session,
skip_permission_check=True,
prefetch_tool_calls=prefetch_tool_calls,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}

View File

@@ -0,0 +1,24 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -1,13 +1,17 @@
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
@@ -76,9 +80,31 @@ def load_personas_from_yaml(
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
upsert_persona(
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
@@ -88,20 +114,52 @@ def load_personas_from_yaml(
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)

View File

@@ -46,15 +46,22 @@ class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
# TODO make this document level. Also slight misnomer here as this is actually
# done at the section level currently rather than the chunk
relevant: bool | None = None
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class SectionRelevancePiece(RelevanceAnalysis):
"""LLM analysis mapped to an Inference Section"""
document_id: str
chunk_id: int # ID of the center chunk for a given inference section
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class DanswerAnswerPiece(BaseModel):

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Danswer"
name: "Knowledge"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
@@ -17,7 +17,7 @@ personas:
num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
llm_relevance_filter: true
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
llm_filter_extraction: true
@@ -37,12 +37,15 @@ personas:
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 1
is_visible: true
- id: 1
name: "GPT"
name: "General"
description: >
Assistant with no access to documents. Chat with just the Language Model.
Assistant with no access to documents. Chat with just the Large Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -50,7 +53,10 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 0
is_visible: true
- id: 2
name: "Paraphrase"
@@ -63,3 +69,25 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 45519
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
- id: 3
name: "Art"
description: >
Assistant for generating images based on descriptions.
prompts:
- "ImageGeneration"
num_chunks: 0
llm_relevance_filter: false
llm_filter_extraction: false
recency_bias: "no_decay"
document_sets: []
icon_shape: 234124
icon_color: "#9B59B6"
image_generation: true
display_priority: 3
is_visible: true

View File

@@ -51,7 +51,8 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -60,6 +61,7 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@@ -187,37 +189,46 @@ def _handle_internet_search_tool_response_summary(
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
return None
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if (
new_msg_req.query_override
or (
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
@@ -253,7 +264,6 @@ def stream_chat_message_objects(
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
@@ -274,7 +284,10 @@ def stream_chat_message_objects(
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id, user=user, db_session=db_session
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
persona = chat_session.persona
@@ -297,7 +310,13 @@ def stream_chat_message_objects(
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_tokenizer = get_default_llm_tokenizer()
llm_provider = llm.config.model_provider
llm_model_name = llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -361,6 +380,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -476,6 +503,9 @@ def stream_chat_message_objects(
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -544,9 +574,11 @@ def stream_chat_message_objects(
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
@@ -576,11 +608,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None
@@ -606,18 +634,28 @@ def stream_chat_message_objects(
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
relevance_sections = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -655,18 +693,28 @@ def stream_chat_message_objects(
yield cast(ChatPacket, packet)
except Exception as e:
logger.exception("Failed to process chat message")
# Don't leak the API key
error_msg = str(e)
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
logger.exception(f"Failed to process chat message: {error_msg}")
if "Illegal header value b'Bearer '" in error_msg:
error_msg = (
f"LLM failed to respond. Invalid API "
f"key error from '{llm.config.model_provider}'."
f"Authentication error: Invalid or empty API key provided for '{llm.config.model_provider}'. "
"Please check your API key configuration."
)
elif (
"Invalid leading whitespace, reserved character(s), or return character(s) in header value"
in error_msg
):
error_msg = (
f"Authentication error: Invalid API key format for '{llm.config.model_provider}'. "
"Please ensure your API key does not contain leading/trailing whitespace or invalid characters."
)
elif llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
error_msg = f"LLM failed to respond. Invalid API key error from '{llm.config.model_provider}'."
else:
error_msg = "An unexpected error occurred while processing your request. Please try again later."
yield StreamingError(error=error_msg)
# Cancel the transaction so that no messages are saved
db_session.rollback()
return

View File

@@ -30,7 +30,23 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images based on user prompts!"
system: >
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
- name: "OnlyLLM"
description: "Chat directly with the LLM!"

View File

@@ -129,6 +129,17 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
# recycle timeout in seconds
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20 # 20 minutes
try:
POSTGRES_POOL_RECYCLE = int(
os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
)
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
#####
# Connector Configs
@@ -212,10 +223,11 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
PRUNING_DISABLED = -1
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -248,6 +260,9 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention

View File

@@ -3,6 +3,7 @@ import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -32,11 +33,6 @@ DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
@@ -63,6 +59,7 @@ HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
@@ -75,16 +72,16 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_DOC_RELEVANCE = (
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
)
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None

View File

@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
@@ -60,6 +59,14 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"

View File

@@ -12,7 +12,7 @@ import os
# The useable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
@@ -34,10 +34,12 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0

View File

@@ -217,16 +217,19 @@ class RecursiveIndexer:
self,
batch_size: int,
confluence_client: Confluence,
index_origin: bool,
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.confluence_client = confluence_client
self.index_origin = index_origin
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
@@ -282,12 +285,11 @@ class RecursiveIndexer:
current_level_pages = next_level_pages
next_level_pages = []
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
@@ -340,7 +342,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_origin: bool = True,
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -352,7 +354,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_origin = index_origin
self.index_recursively = index_recursively
(
self.wiki_base,
self.space,
@@ -369,7 +371,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -453,10 +455,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_origin=self.index_origin,
index_recursively=self.index_recursively,
)
return self.recursive_indexer.get_pages(start_ind, batch_size)
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
pages: list[dict[str, Any]] = []

View File

@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
doc_batch.append(
Document(

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

View File

@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import CRED_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -146,6 +147,7 @@ def build_service_account_creds(
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)

View File

@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.constants import CRED_KEY
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -118,6 +119,7 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
@@ -131,6 +133,7 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)

View File

@@ -86,7 +86,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: The categories to include in the index.
pages: The pages to include in the index.
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
connector_name: The name of the connector.
language_code: The language code of the wiki.
batch_size: The batch size for loading documents.
@@ -104,7 +103,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -118,10 +116,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.connector_name = connector_name
connector_name = "".join(ch for ch in connector_name if ch.isalnum())
self.family = family_class_dispatch(hostname, connector_name)()
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
@@ -210,7 +206,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
HOSTNAME = "fallout.fandom.com"
test_connector = MediaWikiConnector(
connector_name="Fallout",
hostname=HOSTNAME,
categories=["Fallout:_New_Vegas_factions"],
pages=["Fallout: New Vegas"],

View File

@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
title: str | None = None
from_ingestion_api: bool = False
def get_title_for_document_index(self) -> str | None:
def get_title_for_document_index(
self,
) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None

View File

@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
try:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.SSLError as e:
cause = (
e.args[0].reason
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
else e.args
)
raise Exception(f"SSL error {str(cause)}")
except (requests.RequestException, ValueError):
raise Exception(f"Unable to reach {url} - check your internet connection")

View File

@@ -15,7 +15,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -24,7 +23,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories=categories,
pages=pages,
recurse_depth=recurse_depth,
connector_name=connector_name,
language_code=language_code,
batch_size=batch_size,
)

View File

@@ -1,5 +1,7 @@
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
@@ -19,12 +21,24 @@ from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
def _article_to_document(article: Article) -> Document:
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
labels = [str(label) for label in article.label_names]
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return Document(
id=f"article:{article.id}",
@@ -35,7 +49,7 @@ def _article_to_document(article: Article) -> Document:
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author],
metadata={"labels": labels} if labels else {},
metadata=metadata,
)
@@ -48,6 +62,42 @@ class ZendeskConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -62,6 +112,11 @@ class ZendeskConnector(LoadConnector, PollConnector):
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -92,10 +147,30 @@ class ZendeskConnector(LoadConnector, PollConnector):
):
continue
doc_batch.append(_article_to_document(article))
doc_batch.append(_article_to_document(article, self.content_tags))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
if doc_batch:
yield doc_batch
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()
connector.load_credentials(
{
"zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
"zendesk_email": os.environ["ZENDESK_EMAIL"],
"zendesk_token": os.environ["ZENDESK_TOKEN"],
}
)
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
print(next(document_batches))

View File

@@ -70,6 +70,10 @@ def _process_citations_for_slack(text: str) -> str:
def slack_link_format(match: Match) -> str:
link_text = match.group(1)
link_url = match.group(2)
# Account for empty link citations
if link_url == "":
return f"[{link_text}]"
return f"<{link_url}|[{link_text}]>"
# Substitute all matches in the input text
@@ -299,7 +303,9 @@ def build_sources_blocks(
else []
)
+ [
MarkdownTextObject(
MarkdownTextObject(text=f"{document_title}")
if d.link == ""
else MarkdownTextObject(
text=f"*<{d.link}|[{citation_num}] {document_title}>*\n{final_metadata_str}"
),
]

View File

@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
@@ -471,8 +471,7 @@ if __name__ == "__main__":
embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)

View File

@@ -16,7 +16,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.chat.models import DocumentRelevance
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
@@ -541,11 +541,11 @@ def get_doc_query_identifiers_from_model(
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
relevance_summary: DocumentRelevance,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
search_doc.document_id
)
if relevance_data is not None:
db_session.execute(

View File

@@ -11,6 +11,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ObjectCreationIdResponse
@@ -85,9 +86,8 @@ def create_connector(
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
prune_freq=connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ,
indexing_start=connector_data.indexing_start,
prune_freq=connector_data.prune_freq,
disabled=connector_data.disabled,
)
db_session.add(connector)
@@ -191,7 +191,8 @@ def fetch_latest_index_attempt_by_connector(
for connector in connectors:
latest_index_attempt = (
db_session.query(IndexAttempt)
.filter(IndexAttempt.connector_id == connector.id)
.join(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.connector_id == connector.id)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
@@ -207,13 +208,11 @@ def fetch_latest_index_attempts_by_status(
) -> list[IndexAttempt]:
subquery = (
db_session.query(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
IndexAttempt.connector_credential_pair_id,
IndexAttempt.status,
func.max(IndexAttempt.time_updated).label("time_updated"),
)
.group_by(IndexAttempt.connector_id)
.group_by(IndexAttempt.credential_id)
.group_by(IndexAttempt.connector_credential_pair_id)
.group_by(IndexAttempt.status)
.subquery()
)
@@ -223,12 +222,13 @@ def fetch_latest_index_attempts_by_status(
query = db_session.query(IndexAttempt).join(
alias,
and_(
IndexAttempt.connector_id == alias.connector_id,
IndexAttempt.credential_id == alias.credential_id,
IndexAttempt.connector_credential_pair_id
== alias.connector_credential_pair_id,
IndexAttempt.status == alias.status,
IndexAttempt.time_updated == alias.time_updated,
),
)
return cast(list[IndexAttempt], query.all())
@@ -247,20 +247,31 @@ def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
def create_initial_default_connector(db_session: Session) -> None:
default_connector_id = 0
default_connector = fetch_connector_by_id(default_connector_id, db_session)
if default_connector is not None:
if (
default_connector.source != DocumentSource.INGESTION_API
or default_connector.input_type != InputType.LOAD_STATE
or default_connector.refresh_freq is not None
or default_connector.disabled
or default_connector.name != "Ingestion API"
or default_connector.connector_specific_config != {}
or default_connector.prune_freq is not None
):
raise ValueError(
"DB is not in a valid initial state. "
"Default connector does not have expected values."
logger.warning(
"Default connector does not have expected values. Updating to proper state."
)
# Ensure default connector has correct valuesg
default_connector.source = DocumentSource.INGESTION_API
default_connector.input_type = InputType.LOAD_STATE
default_connector.refresh_freq = None
default_connector.disabled = False
default_connector.name = "Ingestion API"
default_connector.connector_specific_config = {}
default_connector.prune_freq = None
db_session.commit()
return
# Create a new default connector if it doesn't exist
connector = Connector(
id=default_connector_id,
name="Ingestion API",
@@ -269,6 +280,7 @@ def create_initial_default_connector(db_session: Session) -> None:
connector_specific_config={},
refresh_freq=None,
prune_freq=None,
disabled=False,
)
db_session.add(connector)
db_session.commit()

View File

@@ -6,6 +6,7 @@ from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
@@ -42,6 +43,17 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()
def get_connector_credential_source_from_id(
cc_pair_id: int,
db_session: Session,
) -> DocumentSource | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
cc_pair = result.scalar_one_or_none()
return cc_pair.connector.source if cc_pair else None
def get_connector_credential_pair_from_id(
cc_pair_id: int,
db_session: Session,
@@ -75,17 +87,23 @@ def get_last_successful_attempt_time(
# For Secondary Index we don't keep track of the latest success, so have to calculate it live
attempt = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model.id,
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.order_by(IndexAttempt.time_started.desc())
.first()
)
if not attempt or not attempt.time_started:
connector = fetch_connector_by_id(connector_id, db_session)
if connector and connector.indexing_start:
return connector.indexing_start.timestamp()
return 0.0
return attempt.time_started.timestamp()
@@ -241,6 +259,12 @@ def remove_credential_from_connector(
)
def fetch_connector_credential_pairs(
db_session: Session,
) -> list[ConnectorCredentialPair]:
return db_session.query(ConnectorCredentialPair).all()
def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
db_session: Session,
@@ -253,10 +277,14 @@ def resync_cc_pair(
) -> IndexAttempt | None:
query = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
EmbeddingModel.status == IndexModelStatus.PRESENT,
)
)

View File

@@ -2,10 +2,13 @@ from typing import Any
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_
from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@@ -14,8 +17,10 @@ from danswer.connectors.google_drive.constants import (
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger
@@ -74,6 +79,69 @@ def fetch_credential_by_id(
return credential
def fetch_credentials_by_source(
db_session: Session,
user: User | None,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
base_query = _attach_user_filters(base_query, user)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
def swap_credentials_connector(
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
) -> ConnectorCredentialPair:
# Check if the user has permission to use the new credential
new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
if not new_credential:
raise ValueError(
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
)
# Existing pair
existing_pair = db_session.execute(
select(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id
)
).scalar_one_or_none()
if not existing_pair:
raise ValueError(
f"No ConnectorCredentialPair found for connector_id {connector_id}"
)
# Check if the new credential is compatible with the connector
if new_credential.source != existing_pair.connector.source:
raise ValueError(
f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
)
db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id
== existing_pair.credential_id,
)
)
.values(credential_id=new_credential_id)
)
# Update the existing pair with the new credential
existing_pair.credential_id = new_credential_id
existing_pair.credential = new_credential
# Commit the changes
db_session.commit()
# Refresh the object to ensure all relationships are up-to-date
db_session.refresh(existing_pair)
return existing_pair
def create_credential(
credential_data: CredentialBase,
user: User | None,
@@ -83,6 +151,8 @@ def create_credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
admin_public=credential_data.admin_public,
source=credential_data.source,
name=credential_data.name,
)
db_session.add(credential)
db_session.commit()
@@ -90,6 +160,28 @@ def create_credential(
return credential
def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
user: User,
db_session: Session,
) -> Credential | None:
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.name = credential_data.name
# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
credential.user_id = user.id if user is not None else None
db_session.commit()
return credential
def update_credential(
credential_id: int,
credential_data: CredentialBase,
@@ -136,6 +228,7 @@ def delete_credential(
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
@@ -149,11 +242,38 @@ def delete_credential(
.all()
)
if associated_connectors:
raise ValueError(
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
"Please delete all associated connectors first."
)
associated_doc_cc_pairs = (
db_session.query(DocumentByConnectorCredentialPair)
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
.all()
)
if associated_connectors or associated_doc_cc_pairs:
if force:
logger.warning(
f"Force deleting credential {credential_id} and its associated records"
)
# Delete DocumentByConnectorCredentialPair records first
for doc_cc_pair in associated_doc_cc_pairs:
db_session.delete(doc_cc_pair)
# Then delete ConnectorCredentialPair records
for connector in associated_connectors:
db_session.delete(connector)
# Commit these deletions before deleting the credential
db_session.flush()
else:
raise ValueError(
f"Cannot delete credential as it is still associated with "
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
)
if force:
logger.info(f"Force deleting credential {credential_id}")
else:
logger.info(f"Deleting credential {credential_id}")
db_session.delete(credential)
db_session.commit()

View File

@@ -311,7 +311,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
_NUM_LOCK_ATTEMPTS = 10
_LOCK_RETRY_DELAY = 30
_LOCK_RETRY_DELAY = 10
@contextlib.contextmanager

View File

@@ -277,7 +277,7 @@ def mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
`cc_pair_id` as not current and returns the list of all document set IDs
affected.
NOTE: rases a `ValueError` if any of the document sets are currently syncing
NOTE: raises a `ValueError` if any of the document sets are currently syncing
to avoid getting into a bad state."""
document_set__cc_pair_relationships = db_session.scalars(
select(DocumentSet__ConnectorCredentialPair).where(

View File

@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -16,8 +16,11 @@ from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -25,12 +28,18 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
@@ -51,24 +60,50 @@ def build_connection_string(
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(db_api=SYNC_DB_API)
_SYNC_ENGINE = create_engine(connection_string, pool_size=40, max_overflow=10)
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _SYNC_ENGINE
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string, pool_size=40, max_overflow=10
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _ASYNC_ENGINE
@@ -115,4 +150,8 @@ async def warm_up_connections(
await async_conn.close()
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory

View File

@@ -15,6 +15,7 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
@@ -23,6 +24,22 @@ from danswer.utils.telemetry import RecordType
logger = setup_logger()
def get_last_attempt_for_cc_pair(
cc_pair_id: int,
embedding_model_id: int,
db_session: Session,
) -> IndexAttempt | None:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -31,15 +48,13 @@ def get_index_attempt(
def create_index_attempt(
connector_id: int,
credential_id: int,
connector_credential_pair_id: int,
embedding_model_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
new_attempt = IndexAttempt(
connector_id=connector_id,
credential_id=credential_id,
connector_credential_pair_id=connector_credential_pair_id,
embedding_model_id=embedding_model_id,
from_beginning=from_beginning,
status=IndexingStatus.NOT_STARTED,
@@ -56,7 +71,9 @@ def get_inprogress_index_attempts(
) -> list[IndexAttempt]:
stmt = select(IndexAttempt)
if connector_id is not None:
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
stmt = stmt.where(
IndexAttempt.connector_credential_pair.has(connector_id=connector_id)
)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS)
incomplete_attempts = db_session.scalars(stmt)
@@ -65,21 +82,31 @@ def get_inprogress_index_attempts(
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage"""
before running long-living indexing jobs, which causes increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
),
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
),
)
new_attempts = db_session.scalars(stmt)
return list(new_attempts.all())
def mark_attempt_in_progress__no_commit(
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
def mark_attempt_succeeded(
@@ -103,7 +130,7 @@ def mark_attempt_failed(
db_session.add(index_attempt)
db_session.commit()
source = index_attempt.connector.source
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
@@ -128,11 +155,16 @@ def get_last_attempt(
embedding_model_id: int | None,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
)
# Note, the below is using time_created instead of time_updated
stmt = stmt.order_by(desc(IndexAttempt.time_created))
@@ -145,8 +177,7 @@ def get_latest_index_attempts(
db_session: Session,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.time_created).label("max_time_created"),
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
@@ -158,43 +189,101 @@ def get_latest_index_attempts(
where_stmts: list[ColumnElement] = []
for connector_credential_pair_identifier in connector_credential_pair_identifiers:
where_stmts.append(
and_(
IndexAttempt.connector_id
== connector_credential_pair_identifier.connector_id,
IndexAttempt.credential_id
== connector_credential_pair_identifier.credential_id,
IndexAttempt.connector_credential_pair_id
== (
select(ConnectorCredentialPair.id)
.where(
ConnectorCredentialPair.connector_id
== connector_credential_pair_identifier.connector_id,
ConnectorCredentialPair.credential_id
== connector_credential_pair_identifier.credential_id,
)
.scalar_subquery()
)
)
if where_stmts:
ids_stmt = ids_stmt.where(or_(*where_stmts))
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_id, IndexAttempt.credential_id)
ids_subqery = ids_stmt.subquery()
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
stmt = (
select(IndexAttempt)
.join(
ids_subqery,
and_(
ids_subqery.c.connector_id == IndexAttempt.connector_id,
ids_subqery.c.credential_id == IndexAttempt.credential_id,
),
ids_subquery,
IndexAttempt.connector_credential_pair_id
== ids_subquery.c.connector_credential_pair_id,
)
.where(IndexAttempt.time_created == ids_subqery.c.max_time_created)
.where(IndexAttempt.time_created == ids_subquery.c.max_time_created)
)
return db_session.execute(stmt).scalars().all()
def get_index_attempts_for_connector(
db_session: Session,
connector_id: int,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(ConnectorCredentialPair.connector_id == connector_id)
)
if disinclude_finished:
stmt = stmt.where(
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
)
)
if only_current:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(IndexAttempt.time_created.desc())
return db_session.execute(stmt).scalars().all()
def get_latest_finished_index_attempt_for_cc_pair(
connector_credential_pair_id: int,
secondary_index: bool,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def get_index_attempts_for_cc_pair(
db_session: Session,
cc_pair_identifier: ConnectorCredentialPairIdentifier,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = select(IndexAttempt).where(
and_(
IndexAttempt.connector_id == cc_pair_identifier.connector_id,
IndexAttempt.credential_id == cc_pair_identifier.credential_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
and_(
ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id,
ConnectorCredentialPair.credential_id
== cc_pair_identifier.credential_id,
)
)
)
if disinclude_finished:
@@ -218,9 +307,11 @@ def delete_index_attempts(
db_session: Session,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
)
db_session.execute(stmt)
@@ -254,9 +345,11 @@ def cancel_indexing_attempts_for_connector(
db_session: Session,
include_secondary_index: bool = False,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
stmt = (
delete(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id)
.where(ConnectorCredentialPair.connector_id == connector_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
)
if not include_secondary_index:
@@ -296,7 +389,8 @@ def count_unique_cc_pairs_with_successful_index_attempts(
Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally,
do a count to get the total number of unique cc-pairs with successful attempts"""
unique_pairs_count = (
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
db_session.query(IndexAttempt.connector_credential_pair_id)
.join(ConnectorCredentialPair)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.status == IndexingStatus.SUCCESS,

View File

@@ -0,0 +1,202 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)
if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())

View File

@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,36 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
existing_llm_provider.display_model_names = llm_provider.display_model_names
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@@ -74,8 +100,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@@ -119,6 +166,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)

View File

@@ -11,6 +11,7 @@ from uuid import UUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import Enum
@@ -120,6 +121,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.ARRAY(Integer), nullable=True
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
@@ -132,12 +137,39 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
# Custom tools created by this user
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -337,6 +369,9 @@ class ConnectorCredentialPair(Base):
back_populates="connector_credential_pairs",
overlaps="document_set",
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="connector_credential_pair"
)
class Document(Base):
@@ -416,6 +451,9 @@ class Connector(Base):
connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
postgresql.JSONB()
)
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -434,14 +472,17 @@ class Connector(Base):
documents_by_connector: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="connector"
)
class Credential(Base):
__tablename__ = "credential"
name: Mapped[str] = mapped_column(String, nullable=True)
source: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
@@ -462,9 +503,7 @@ class Credential(Base):
documents_by_credential: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="credential"
)
user: Mapped[User | None] = relationship("User", back_populates="credentials")
@@ -516,12 +555,12 @@ class EmbeddingModel(Base):
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider is not None else None
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
class IndexAttempt(Base):
@@ -534,13 +573,10 @@ class IndexAttempt(Base):
__tablename__ = "index_attempt"
id: Mapped[int] = mapped_column(primary_key=True)
connector_id: Mapped[int | None] = mapped_column(
ForeignKey("connector.id"),
nullable=True,
)
credential_id: Mapped[int | None] = mapped_column(
ForeignKey("credential.id"),
nullable=True,
connector_credential_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"),
nullable=False,
)
# Some index attempts that run from beginning will still have this as False
@@ -578,12 +614,10 @@ class IndexAttempt(Base):
onupdate=func.now(),
)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="index_attempts"
)
credential: Mapped[Credential] = relationship(
"Credential", back_populates="index_attempts"
connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
"ConnectorCredentialPair", back_populates="index_attempts"
)
embedding_model: Mapped[EmbeddingModel] = relationship(
"EmbeddingModel", back_populates="index_attempts"
)
@@ -591,8 +625,7 @@ class IndexAttempt(Base):
__table_args__ = (
Index(
"ix_index_attempt_latest_for_connector_credential_pair",
"connector_id",
"credential_id",
"connector_credential_pair_id",
"time_created",
),
)
@@ -600,7 +633,6 @@ class IndexAttempt(Base):
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
f"connector_id={self.connector_id!r}, "
f"status={self.status!r}, "
f"error_msg={self.error_msg!r})>"
f"time_created={self.time_created!r}, "
@@ -821,6 +853,8 @@ class ChatMessage(Base):
secondary="chat_message__search_doc",
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
@@ -923,6 +957,11 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually disp;aly to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# The LLMs that are available for this provider. Only required if not a default provider.
# If a default provider, then the LLM options are pulled from the `options.py` file.
# If needed, can be pulled out as a separate table in the future.
@@ -932,6 +971,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@@ -1107,9 +1153,14 @@ class Persona(Base):
# controls the ordering of personas in the UI
# higher priority personas are displayed first, ties are resolved by the ID,
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
display_priority: Mapped[int | None] = mapped_column(
Integer, nullable=True, default=None
)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
icon_color: Mapped[str | None] = mapped_column(String, nullable=True)
icon_shape: Mapped[int | None] = mapped_column(Integer, nullable=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1188,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@@ -1360,6 +1412,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@@ -9,6 +9,7 @@ from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
@@ -24,6 +25,7 @@ from danswer.db.models import StarterMessage
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
@@ -80,6 +82,9 @@ def create_update_persona(
starter_messages=create_persona_request.starter_messages,
is_public=create_persona_request.is_public,
db_session=db_session,
icon_color=create_persona_request.icon_color,
icon_shape=create_persona_request.icon_shape,
uploaded_image_id=create_persona_request.uploaded_image_id,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -328,6 +333,11 @@ def upsert_persona(
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
icon_color: str | None = None,
icon_shape: int | None = None,
uploaded_image_id: str | None = None,
display_priority: int | None = None,
is_visible: bool = True,
) -> Persona:
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
@@ -383,6 +393,11 @@ def upsert_persona(
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
# Do not delete any associations manually added unless
# a new updated list is provided
@@ -415,6 +430,11 @@ def upsert_persona(
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
tools=tools or [],
icon_shape=icon_shape,
icon_color=icon_color,
uploaded_image_id=uploaded_image_id,
display_priority=display_priority,
is_visible=is_visible,
)
db_session.add(persona)
@@ -548,6 +568,8 @@ def get_default_prompt__read_only() -> Prompt:
return _get_default_prompt(db_session)
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
persona_id: int,
# if user is `None` assume the user is an admin or auth is disabled
@@ -556,16 +578,38 @@ def get_persona_by_id(
include_deleted: bool = False,
is_for_edit: bool = True, # NOTE: assume true for safety
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
stmt = (
select(Persona)
.options(selectinload(Persona.users), selectinload(Persona.groups))
.where(Persona.id == persona_id)
)
or_conditions = []
# if user is an admin, they should have access to all Personas
# and will skip the following clause
if user is not None and user.role != UserRole.ADMIN:
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
# the user is not an admin
isPersonaUnowned = Persona.user_id.is_(
None
) # allow access if persona user id is None
isUserCreator = (
Persona.user_id == user.id
) # allow access if user created the persona
or_conditions.extend([isPersonaUnowned, isUserCreator])
# if we aren't editing, also give access to all public personas
# if we aren't editing, also give access if:
# 1. the user is authorized for this persona
# 2. the user is in an authorized group for this persona
# 3. if the persona is public
if not is_for_edit:
isSharedWithUser = Persona.users.any(
id=user.id
) # allow access if user is in allowed users
isSharedWithGroup = Persona.groups.any(
UserGroup.users.any(id=user.id)
) # allow access if user is in any allowed group
or_conditions.extend([isSharedWithUser, isSharedWithGroup])
or_conditions.append(Persona.is_public.is_(True))
if or_conditions:

View File

@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding
@dataclass(frozen=True)
@@ -257,7 +258,7 @@ class VectorCapable(abc.ABC):
def semantic_retrieval(
self,
query: str, # Needed for matching purposes
query_embedding: list[float],
query_embedding: Embedding,
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int,
@@ -292,7 +293,7 @@ class HybridCapable(abc.ABC):
def hybrid_retrieval(
self,
query: str,
query_embedding: list[float],
query_embedding: Embedding,
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int,

View File

@@ -69,6 +69,7 @@ from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -329,11 +330,13 @@ def _index_vespa_chunk(
"Content-Type": "application/json",
}
document = chunk.source_document
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
embeddings = chunk.embeddings
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
@@ -346,11 +349,15 @@ def _index_vespa_chunk(
BLURB: remove_invalid_unicode_chars(chunk.blurb),
TITLE: remove_invalid_unicode_chars(title) if title else None,
SKIP_TITLE_EMBEDDING: not title,
CONTENT: remove_invalid_unicode_chars(chunk.content),
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
# which contains the title prefix and metadata suffix
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -358,7 +365,7 @@ def _index_vespa_chunk(
METADATA: json.dumps(document.metadata),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix,
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,
@@ -1025,7 +1032,7 @@ class VespaIndex(DocumentIndex):
def semantic_retrieval(
self,
query: str,
query_embedding: list[float],
query_embedding: Embedding,
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int = NUM_RETURNED_HITS,
@@ -1067,7 +1074,7 @@ class VespaIndex(DocumentIndex):
def hybrid_retrieval(
self,
query: str,
query_embedding: list[float],
query_embedding: Embedding,
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int,

View File

@@ -103,6 +103,8 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name,
model_names=None,
display_model_names=[],
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id)

View File

@@ -8,7 +8,7 @@ from typing import cast
from filelock import FileLock
from sqlalchemy.orm import Session
from danswer.db.engine import SessionFactory
from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
@@ -56,7 +56,8 @@ class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
class PostgresBackedDynamicConfigStore(DynamicConfigStore):
@contextmanager
def get_session(self) -> Iterator[Session]:
session: Session = SessionFactory()
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:

View File

@@ -1,12 +1,13 @@
import abc
from collections.abc import Callable
from typing import Optional
from typing import TYPE_CHECKING
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -14,13 +15,14 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
if TYPE_CHECKING:
from transformers import AutoTokenizer # type:ignore
from llama_index.text_splitter import SentenceSplitter # type:ignore
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
@@ -28,6 +30,8 @@ if TYPE_CHECKING:
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
# overwhelm the actual contents of the chunk
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
# could be another 128 tokens leaving 256 for the actual contents
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
@@ -36,15 +40,11 @@ logger = setup_logger()
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
def extract_blurb(text: str, blurb_size: int) -> str:
from llama_index.text_splitter import SentenceSplitter
token_count_func = get_default_tokenizer().tokenize
blurb_splitter = SentenceSplitter(
tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
)
return blurb_splitter.split_text(text)[0]
def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
texts = blurb_splitter.split_text(text)
if not texts:
return ""
return texts[0]
def chunk_large_section(
@@ -52,76 +52,129 @@ def chunk_large_section(
section_link_text: str,
document: Document,
start_chunk_id: int,
tokenizer: "AutoTokenizer",
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
title_prefix: str = "",
metadata_suffix: str = "",
blurb: str,
chunk_splitter: "SentenceSplitter",
mini_chunk_splitter: Optional["SentenceSplitter"],
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
blurb = extract_blurb(section_text, blurb_size)
sentence_aware_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
split_texts = sentence_aware_splitter.split_text(section_text)
split_texts = chunk_splitter.split_text(section_text)
chunks = [
DocAwareChunk(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
blurb=blurb,
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
content_summary=chunk_str,
content=chunk_text,
source_links={0: section_link_text},
section_continuation=(chunk_ind != 0),
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if mini_chunk_splitter and chunk_text.strip()
else None,
)
for chunk_ind, chunk_str in enumerate(split_texts)
for chunk_ind, chunk_text in enumerate(split_texts)
]
return chunks
def _get_metadata_suffix_for_document_index(
metadata: dict[str, str | list[str]]
) -> str:
metadata: dict[str, str | list[str]], include_separator: bool = False
) -> tuple[str, str]:
"""
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
and a string of all of the values for the keyword search
For example, if we have the following metadata:
{
"author": "John Doe",
"space": "Engineering"
}
The vector embedding string should include the relation between the key and value wheres as for keyword we only want John Doe
and Engineering. The keys are repeat and much more noisy.
"""
if not metadata:
return ""
return "", ""
metadata_str = "Metadata:\n"
metadata_values = []
for key, value in metadata.items():
if key in get_metadata_keys_to_ignore():
continue
value_str = ", ".join(value) if isinstance(value, list) else value
if isinstance(value, list):
metadata_values.extend(value)
else:
metadata_values.append(value)
metadata_str += f"\t{key} - {value_str}\n"
return metadata_str.strip()
metadata_semantic = metadata_str.strip()
metadata_keyword = " ".join(metadata_values)
if include_separator:
return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
return metadata_semantic, metadata_keyword
def chunk_document(
document: Document,
embedder: IndexingEmbedder,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
blurb_size: int = BLURB_SIZE, # Used for both title and content
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
mini_chunk_size: int = MINI_CHUNK_SIZE,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[DocAwareChunk]:
tokenizer = get_default_tokenizer()
from llama_index.text_splitter import SentenceSplitter
title = document.get_title_for_document_index()
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
tokenizer = get_tokenizer(
model_name=embedder.model_name,
provider_type=embedder.provider_type,
)
blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
)
chunk_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=chunk_tok_size,
chunk_overlap=subsection_overlap,
)
mini_chunk_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=mini_chunk_size,
chunk_overlap=0,
)
title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(tokenizer.tokenize(title_prefix))
metadata_suffix = ""
metadata_suffix_semantic = ""
metadata_suffix_keyword = ""
metadata_tokens = 0
if include_metadata:
metadata = _get_metadata_suffix_for_document_index(document.metadata)
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
(
metadata_suffix_semantic,
metadata_suffix_keyword,
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
metadata_suffix = ""
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
# context, there is no limit for the keyword component
metadata_suffix_semantic = ""
metadata_tokens = 0
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
@@ -130,7 +183,7 @@ def chunk_document(
if content_token_limit <= CHUNK_MIN_CONTENT:
content_token_limit = chunk_tok_size
title_prefix = ""
metadata_suffix = ""
metadata_suffix_semantic = ""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
@@ -151,12 +204,16 @@ def chunk_document(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
else None,
)
)
link_offsets = {}
@@ -167,12 +224,14 @@ def chunk_document(
section_link_text=section_link_text,
document=document,
start_chunk_id=len(chunks),
tokenizer=tokenizer,
chunk_size=content_token_limit,
chunk_overlap=subsection_overlap,
blurb_size=blurb_size,
chunk_splitter=chunk_splitter,
mini_chunk_splitter=mini_chunk_splitter
if enable_mini_chunk and chunk_text.strip()
else None,
blurb=extract_blurb(section_text, blurb_splitter),
title_prefix=title_prefix,
metadata_suffix=metadata_suffix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
chunks.extend(large_section_chunks)
continue
@@ -193,60 +252,62 @@ def chunk_document(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
else None,
)
)
link_offsets = {0: section_link_text}
chunk_text = section_text
# Once we hit the end, if we're still in the process of building a chunk, add what we have
# NOTE: if it's just whitespace, ignore it.
if chunk_text.strip():
# Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
# then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
if chunk_text.strip() or not chunks:
chunks.append(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
else None,
)
)
# If the chunk does not have any useable content, it will not be indexed
return chunks
def split_chunk_text_into_mini_chunks(
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
"""The minichunks won't all have the title prefix or metadata suffix
It could be a significant percentage of every minichunk so better to not include it
"""
from llama_index.text_splitter import SentenceSplitter
token_count_func = get_default_tokenizer().tokenize
sentence_aware_splitter = SentenceSplitter(
tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
)
return sentence_aware_splitter.split_text(chunk_text)
class Chunker:
@abc.abstractmethod
def chunk(self, document: Document) -> list[DocAwareChunk]:
def chunk(
self,
document: Document,
embedder: IndexingEmbedder,
) -> list[DocAwareChunk]:
raise NotImplementedError
class DefaultChunker(Chunker):
def chunk(self, document: Document) -> list[DocAwareChunk]:
def chunk(
self,
document: Document,
embedder: IndexingEmbedder,
) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
return chunk_document(document)
return chunk_document(document, embedder=embedder)

View File

@@ -3,23 +3,20 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.models import EmbeddingModel as DbEmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -32,14 +29,21 @@ class IndexingEmbedder(ABC):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
provider_type: str | None,
api_key: str | None,
):
self.model_name = model_name
self.normalize = normalize
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.provider_type = provider_type
self.api_key = api_key
@abstractmethod
def embed_chunks(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
def embed_chunks(
self,
chunks: list[DocAwareChunk],
) -> list[IndexChunk]:
raise NotImplementedError
@@ -50,10 +54,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None = None,
provider_type: str | None = None,
api_key: str | None = None,
):
super().__init__(model_name, normalize, query_prefix, passage_prefix)
super().__init__(
model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
)
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
self.embedding_model = EmbeddingModel(
@@ -66,72 +72,63 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
)
def embed_chunks(
self,
chunks: list[DocAwareChunk],
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[IndexChunk]:
# Cache the Title embeddings to only have to do it once
title_embed_dict: dict[str, list[float]] = {}
embedded_chunks: list[IndexChunk] = []
# All chunks at this point must have some non-empty content
flat_chunk_texts: list[str] = []
for chunk in chunks:
chunk_text = (
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
) or chunk.source_document.get_title_for_document_index()
# Create Mini Chunks for more precise matching of details
# Off by default with unedited settings
chunk_texts = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
mini_chunk_texts = (
split_chunk_text_into_mini_chunks(chunk.content_summary)
if enable_mini_chunk
else []
)
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
if not chunk_text:
# This should never happen, the document would have been dropped
# before getting to this point
raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
flat_chunk_texts.append(chunk_text)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
if chunk.mini_chunk_texts:
flat_chunk_texts.extend(chunk.mini_chunk_texts)
# Replace line above with the line below for easy debugging of indexing flow
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
embeddings = self.embedding_model.encode(
flat_chunk_texts, text_type=EmbedTextType.PASSAGE
)
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
}
# Drop any None or empty strings
# If there is no title or the title is empty, the title embedding field will be null
# which is ok, it just won't contribute at all to the scoring.
chunk_titles_list = [title for title in chunk_titles if title]
# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
# Cache the Title embeddings to only have to do it once
title_embed_dict: dict[str, Embedding] = {}
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
chunk_titles_list, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
)
# Mapping embeddings to chunks
embedded_chunks: list[IndexChunk] = []
embedding_ind_start = 0
for chunk_ind, chunk in enumerate(chunks):
num_embeddings = chunk_mini_chunks_count[chunk_ind]
for chunk in chunks:
num_embeddings = 1 + (
len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
)
chunk_embeddings = embeddings[
embedding_ind_start : embedding_ind_start + num_embeddings
]
@@ -184,4 +181,6 @@ def get_embedding_model_from_db_embedding_model(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
provider_type=db_embedding_model.provider_type,
api_key=db_embedding_model.api_key,
)

View File

@@ -1,5 +1,4 @@
from functools import partial
from itertools import chain
from typing import Protocol
from sqlalchemy.orm import Session
@@ -34,7 +33,9 @@ logger = setup_logger()
class IndexingPipelineProtocol(Protocol):
def __call__(
self, documents: list[Document], index_attempt_metadata: IndexAttemptMetadata
self,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
...
@@ -116,7 +117,7 @@ def index_doc_batch(
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
documents: list[Document],
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
@@ -124,6 +125,32 @@ def index_doc_batch(
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
documents = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
elif (
document.title is not None and not document.title.strip() and empty_contents
):
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
else:
documents.append(document)
document_ids = [document.id for document in documents]
db_docs = get_documents_by_ids(
document_ids=document_ids,
@@ -138,6 +165,11 @@ def index_doc_batch(
if not ignore_time_skip
else documents
)
# No docs to update either because the batch is empty or every doc was already indexed
if not updatable_docs:
return 0, 0
updatable_ids = [doc.id for doc in updatable_docs]
# Create records in the source of truth about these documents,
@@ -149,14 +181,21 @@ def index_doc_batch(
)
logger.debug("Starting chunking")
# The first chunk additionally contains the Title of the Document
chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in updatable_docs])
)
# The embedder is needed here to get the correct tokenizer
chunks: list[DocAwareChunk] = [
chunk
for document in updatable_docs
for chunk in chunker.chunk(document=document, embedder=embedder)
]
logger.debug("Starting embedding")
chunks_with_embeddings = embedder.embed_chunks(chunks=chunks)
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
@@ -191,7 +230,7 @@ def index_doc_batch(
]
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@@ -215,7 +254,7 @@ def index_doc_batch(
)
return len([r for r in insertion_records if r.already_existed is False]), len(
chunks
access_aware_chunks
)

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding
if TYPE_CHECKING:
from danswer.db.models import EmbeddingModel
@@ -13,9 +14,6 @@ if TYPE_CHECKING:
logger = setup_logger()
Embedding = list[float]
class ChunkEmbedding(BaseModel):
full_embedding: Embedding
mini_chunk_embeddings: list[Embedding]
@@ -36,15 +34,17 @@ class DocAwareChunk(BaseChunk):
# During inference we only have access to the document id and do not reconstruct the Document
source_document: Document
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
# it's easier to just store a not prefixed/suffixed string for the highlighting
# Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
content_summary: str
# This could be an empty string if the title is too long and taking up too much of the chunk
# This does not mean necessarily that the document does not have a title
title_prefix: str
# During indexing we also (optionally) build a metadata string from the metadata dict
# This is also indexed so that we can strip it out after indexing, this way it supports
# multiple iterations of metadata representation for backwards compatibility
metadata_suffix: str
metadata_suffix_semantic: str
metadata_suffix_keyword: str
mini_chunk_texts: list[str] | None
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""

View File

@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# if specified, tells the LLM to always this tool
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLMs provided tool-calling functonality
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
@@ -139,7 +139,10 @@ class Answer:
self.prompt_config = prompt_config
self.llm = llm
self.llm_tokenizer = get_default_llm_tokenizer()
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
)
self._final_prompt: list[BaseMessage] | None = None
@@ -187,7 +190,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool and self.force_use_tool.args is not None:
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
@@ -221,7 +224,7 @@ class Answer:
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@@ -240,12 +243,26 @@ class Answer:
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool = [
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
][0]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
@@ -263,9 +280,13 @@ class Answer:
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
@@ -286,7 +307,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool:
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
@@ -303,7 +324,7 @@ class Answer:
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
@@ -462,6 +483,7 @@ class Answer:
]
elif message.id == FINAL_CONTEXT_DOCUMENTS:
final_context_docs = cast(list[LlmDoc], message.response)
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts

View File

@@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
@@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
@classmethod
def from_chat_message(
@@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
@@ -66,7 +66,10 @@ class AnswerPromptBuilder:
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -111,8 +114,24 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if tool_call_summary:
final_messages_with_tokens.append((tool_call_summary.tool_call_request, 0))
final_messages_with_tokens.append((tool_call_summary.tool_call_result, 0))
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
@@ -28,6 +28,9 @@ logger = setup_logger()
T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection)
_METADATA_TOKEN_ESTIMATE = 75
# Title and additional tokens as part of the tool message json
# this is only used to log a warning so we can be more forgiving with the buffer
_OVERCOUNT_ESTIMATE = 256
class PruningError(Exception):
@@ -135,8 +138,12 @@ def _apply_pruning(
is_manually_selected_docs: bool,
use_sections: bool,
using_tool_message: bool,
llm_config: LLMConfig,
) -> list[InferenceSection]:
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
sections = deepcopy(sections) # don't modify in place
# re-order docs with all the "relevant" docs at the front
@@ -165,27 +172,36 @@ def _apply_pruning(
)
)
section_tokens = len(llm_tokenizer.encode(section_str))
section_token_count = len(llm_tokenizer.encode(section_str))
# if not using sections (specifically, using Sections where each section maps exactly to the one center chunk),
# truncate chunks that are way too long. This can happen if the embedding model tokenizer is different
# than the LLM tokenizer
if (
not is_manually_selected_docs
and not use_sections
and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
and section_token_count
> DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
):
logger.warning(
"Found more tokens in Section than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
if (
section_token_count
> DOC_EMBEDDING_CONTEXT_SIZE
+ _METADATA_TOKEN_ESTIMATE
+ _OVERCOUNT_ESTIMATE
):
# If the section is just a little bit over, it is likely due to the additional tool message tokens
# no need to record this, the content will be trimmed just in case
logger.info(
"Found more tokens in Section than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
section.combined_content = tokenizer_trim_content(
content=section.combined_content,
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
tokenizer=llm_tokenizer,
)
section_tokens = DOC_EMBEDDING_CONTEXT_SIZE
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
total_tokens += section_tokens
total_tokens += section_token_count
if total_tokens > token_limit:
final_section_ind = ind
break
@@ -273,6 +289,7 @@ def prune_sections(
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
use_sections=document_pruning_config.use_sections, # Now default True
using_tool_message=document_pruning_config.using_tool_message,
llm_config=llm_config,
)

View File

@@ -17,7 +17,6 @@ from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
@@ -27,6 +26,7 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
def _extract_answer_quotes_freeform(
@@ -166,11 +166,8 @@ def process_answer(
into an Answer and Quotes AND (2) after the complete streaming response
has been received to process the model output into an Answer and Quotes."""
answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
if not answer:
logger.debug("No answer extracted from raw output")
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
logger.info(f"Answer: {answer}")
@@ -227,22 +224,26 @@ def process_model_tokens(
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
for token in tokens:
model_previous = model_output
model_output += token
if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
# Note, if the token that completes the pattern has additional text, for example if the token is "?
# Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
# event that the model outputs the UNCERTAINTY_PAT
found_answer_start = True
if not found_answer_start:
m = answer_pattern.search(model_output)
if m:
found_answer_start = True
# Prevent heavy cases of hallucinations where model is not even providing a json until later
if is_json_prompt and len(model_output) > 40:
logger.warning("LLM did not produce json as prompted")
found_answer_end = True
# Prevent heavy cases of hallucinations where model is not even providing a json until later
if is_json_prompt and len(model_output) > 40:
logger.warning("LLM did not produce json as prompted")
found_answer_end = True
continue
continue
remaining = model_output[m.end() :]
if len(remaining) > 0:
yield DanswerAnswerPiece(answer_piece=remaining)
continue
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):

View File

@@ -266,8 +266,14 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=self._temperature,
max_tokens=self._max_output_tokens,
max_tokens=self._max_output_tokens
if self._max_output_tokens > 0
else None,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**self._model_kwargs,
)
except Exception as e:

View File

@@ -70,6 +70,8 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
display_model_names=[],
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

View File

@@ -31,9 +31,7 @@ OPEN_AI_MODEL_NAMES = [
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
# "gpt-4-32k", # not EOL but still doesnt work
"gpt-4-0613",
# "gpt-4-32k-0613", # not EOL but still doesnt work
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-3.5-turbo",
@@ -48,9 +46,11 @@ OPEN_AI_MODEL_NAMES = [
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
::-1
]
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
if "/" not in model and "embed" not in model
][::-1]
IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
@@ -84,7 +84,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
custom_config_keys=[],
llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
default_model="gpt-4",
default_fast_model="gpt-3.5-turbo",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=ANTHROPIC_PROVIDER_NAME,

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +16,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +26,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
@@ -37,60 +34,17 @@ if TYPE_CHECKING:
logger = setup_logger()
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now
files = [] if isinstance(msg, ChatMessage) else msg.files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
files = msg.files
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:
@@ -271,6 +225,13 @@ def check_message_tokens(
elif part["type"] == "image_url":
total_tokens += _IMG_TOKENS
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
total_tokens += check_number_of_tokens(
json.dumps(tool_call["args"]), encode_fn
)
total_tokens += check_number_of_tokens(tool_call["name"], encode_fn)
return total_tokens

View File

@@ -34,6 +34,7 @@ from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -42,16 +43,19 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import EmbeddingModel
from danswer.db.persona import delete_old_default_personas
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -60,6 +64,10 @@ from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@@ -152,8 +160,52 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def setup_postgres(db_session: Session) -> None:
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.info("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.info("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.info("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
def setup_vespa(
document_index: DocumentIndex,
db_embedding_model: EmbeddingModel,
secondary_db_embedding_model: EmbeddingModel | None,
) -> None:
# Vespa startup is a bit slow, so give it a few seconds
wait_time = 5
for _ in range(5):
try:
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
break
except Exception:
logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
time.sleep(wait_time)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)
engine = get_sqlalchemy_engine()
verify_auth = fetch_versioned_implementation(
@@ -206,26 +258,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.info("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.info("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.info("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
# ensure Vespa is setup correctly
logger.info("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
@@ -233,29 +269,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if secondary_db_embedding_model
else None,
)
# Vespa startup is a bit slow, so give it a few seconds
wait_time = 5
for attempt in range(5):
try:
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
break
except Exception:
logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
time.sleep(wait_time)
setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
embedding_model=db_embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -284,6 +306,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)

View File

@@ -0,0 +1,279 @@
import time
import requests
from httpx import HTTPError
from retry import retry
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import EmbeddingModel as DBEmbeddingModel
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
def build_model_server_url(
model_server_host: str,
model_server_port: int,
) -> str:
model_server_url = f"{model_server_host}:{model_server_port}"
# use protocol if provided
if "http" in model_server_url:
return model_server_url
# otherwise default to http
return f"http://{model_server_url}"
class EmbeddingModel:
def __init__(
self,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
# The following are globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json())
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
def _encode_api_model(
self, texts: list[str], text_type: EmbedTextType, batch_size: int
) -> list[Embedding]:
if not self.provider_type:
raise ValueError("Provider type is not set for API embedding")
embeddings: list[Embedding] = []
text_batches = batch_list(texts, batch_size)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return embeddings
def _encode_local_model(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
embeddings: list[Embedding] = []
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return embeddings
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
) -> list[Embedding]:
if not texts or not all(texts):
raise ValueError(f"Empty or missing text for embedding: {texts}")
if self.retrim_content:
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
# However this slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=self.max_seq_length,
tokenizer=get_tokenizer(
model_name=self.model_name,
provider_type=self.provider_type,
),
)
for text in texts
]
if self.provider_type:
return self._encode_api_model(
texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
)
# if no provider, use local model
return self._encode_local_model(
texts=texts, text_type=text_type, batch_size=local_embedding_batch_size
)
class CrossEncoderEnsembleModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
rerank_request = RerankRequest(query=query, documents=passages)
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
class IntentModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.intent_server_endpoint = model_server_url + "/custom/intent-model"
def predict(
self,
query: str,
) -> list[float]:
intent_request = IntentRequest(query=query)
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
)
response.raise_for_status()
return IntentResponse(**response.json()).class_probs
def warm_up_encoders(
embedding_model: DBEmbeddingModel,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_name = embedding_model.model_name
normalize = embedding_model.normalize
provider_type = embedding_model.provider_type
warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at "
"https://docs.danswer.dev/quickstart"
)
# May not be the exact same tokenizer used for the indexing flow
logger.debug(f"Warming up encoder model: {model_name}")
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
warm_up_str
)
embed_model = EmbeddingModel(
model_name=model_name,
normalize=normalize,
provider_type=provider_type,
# Not a big deal if prefix is incorrect
query_prefix=None,
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
)
# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for _ in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
except Exception:
logger.exception(
f"Failed to run test embedding, retrying in {wait_time} seconds..."
)
time.sleep(wait_time)
raise Exception("Failed to run test embedding.")

View File

@@ -0,0 +1,149 @@
import os
from abc import ABC
from abc import abstractmethod
from copy import copy
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
class BaseTokenizer(ABC):
@abstractmethod
def encode(self, string: str) -> list[int]:
pass
@abstractmethod
def tokenize(self, string: str) -> list[str]:
pass
@abstractmethod
def decode(self, tokens: list[int]) -> str:
pass
class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]
def __init__(self, encoding_name: str = "cl100k_base"):
if not hasattr(self, "encoder"):
import tiktoken
self.encoder = tiktoken.get_encoding(encoding_name)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode_ordinary(string)
def tokenize(self, string: str) -> list[str]:
return [self.encoder.decode([token]) for token in self.encode(string)]
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
class HuggingFaceTokenizer(BaseTokenizer):
def __init__(self, model_name: str):
from tokenizers import Tokenizer # type: ignore
self.encoder = Tokenizer.from_pretrained(model_name)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode(string, add_special_tokens=False).ids
def tokenize(self, string: str) -> list[str]:
return self.encoder.encode(string, add_special_tokens=False).tokens
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
global _TOKENIZER_CACHE
if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
try:
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
logger.error(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
) from fallback_error
return _TOKENIZER_CACHE[tokenizer_name]
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
if provider_type:
if provider_type.lower() == "openai":
# Used across ada and text-embedding-3 models
return _check_tokenizer_cache("openai")
# If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
# this means we are approximating the token count which may leave some performance on the table
if not model_name:
raise ValueError("Need to provide a model_name or provider_type")
return _check_tokenizer_cache(model_name)
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk],
tokenizer: BaseTokenizer,
max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[InferenceChunk]:
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks

View File

@@ -9,10 +9,12 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import DocumentRelevance
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import RelevanceAnalysis
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
@@ -34,23 +36,22 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_EVALUATION_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -74,7 +75,7 @@ AnswerObjectIterator = Iterator[
| ChatMessageDetail
| CitationInfo
| ToolCallKickoff
| LLMRelevanceSummaryResponse
| DocumentRelevance
]
@@ -117,8 +118,12 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
llm_tokenizer = get_default_llm_token_encode()
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Create a chat session which will just store the root message, the query, and the AI response
root_message = get_or_create_root_message(
@@ -126,10 +131,12 @@ def stream_answer_objects(
)
history_str = combine_message_thread(
messages=history, max_tokens=max_history_tokens
messages=history,
max_tokens=max_history_tokens,
llm_tokenizer=llm_tokenizer,
)
rephrased_query = thread_based_query_rephrase(
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
)
@@ -158,13 +165,12 @@ def stream_answer_objects(
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=query_msg.message,
token_count=len(llm_tokenizer(query_msg.message)),
token_count=len(llm_tokenizer.encode(query_msg.message)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
@@ -179,6 +185,9 @@ def stream_answer_objects(
search_tool = SearchTool(
db_session=db_session,
user=user,
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
@@ -189,7 +198,6 @@ def stream_answer_objects(
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
llm_doc_eval=query_req.llm_doc_eval,
)
answer_config = AnswerStyleConfig(
@@ -206,6 +214,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
),
@@ -217,7 +226,6 @@ def stream_answer_objects(
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
@@ -256,24 +264,22 @@ def stream_answer_objects(
)
yield initial_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_inds:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_inds,
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(
relevance_summaries=packet.response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
document_based_response = {}
if packet.response is not None:
for evaluation in packet.response:
document_based_response[
evaluation.document_id
] = RelevanceAnalysis(
relevant=evaluation.relevant, content=evaluation.content
)
evaluation_response = DocumentRelevance(
relevance_summaries=document_based_response
)
if reference_db_search_docs is not None:
update_search_docs_table_with_relevance(
@@ -291,7 +297,7 @@ def stream_answer_objects(
parent_message=new_user_message,
prompt_id=query_req.prompt_id,
message=answer.llm_answer,
token_count=len(llm_tokenizer(answer.llm_answer)),
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
message_type=MessageType.ASSISTANT,
error=None,
reference_docs=reference_db_search_docs,

View File

@@ -9,6 +9,7 @@ from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
@@ -27,17 +28,19 @@ class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
agentic: bool | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
# This is to forcibly skip (or run) the step, if None it uses the system defaults
skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
chain_of_thought: bool = False
return_contexts: bool = False
# This is to toggle agentic evaluation:
# 1. Evaluates whether each response is relevant or not
# 2. Provides a summary of the document's relevance in the resulsts
llm_doc_eval: bool = False
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False

View File

@@ -1,8 +1,7 @@
from collections.abc import Callable
from collections.abc import Generator
from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger
@@ -18,7 +17,7 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: Callable | None = None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
@@ -26,8 +25,6 @@ def combine_message_thread(
message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
llm_tokenizer = get_default_llm_token_encode()
for message in reversed(messages):
if message.role == MessageType.USER:
@@ -42,7 +39,7 @@ def combine_message_thread(
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer(msg_str))
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None

View File

@@ -28,7 +28,8 @@ True or False
"""
AGENTIC_SEARCH_USER_PROMPT = """
Document:
Document Title: {title}{optional_metadata}
```
{content}
```

View File

@@ -7,7 +7,6 @@ from danswer.prompts.constants import FINAL_QUERY_PAT
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.prompts.constants import QUESTION_PAT
from danswer.prompts.constants import THOUGHT_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
ONE_SHOT_SYSTEM_PROMPT = """
@@ -66,9 +65,6 @@ EMPTY_SAMPLE_JSON = {
}
ANSWER_NOT_FOUND_RESPONSE = f'{{"answer": "{UNCERTAINTY_PAT}", "quotes": []}}'
# Default json prompt which can reference multiple docs and provide answer + quotes
# system_like_header is similar to system message, can be user provided or defaults to QA_HEADER
# context/history blocks are for context documents and conversation history, they can be blank

View File

@@ -5,12 +5,15 @@
USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
SECTION_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
Determine if the following section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
If the section contains ANY useful information, that is good enough, \
it does not need to fully answer the every part of the user query.
Title: {{title}}
{{optional_metadata}}
Reference Section:
```
{{chunk_text}}

View File

@@ -4,13 +4,6 @@ search/models.py imports from db/models.py."""
from enum import Enum
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"
# Determine whether to run search based on history and latest query
AUTO = "auto"
class RecencyBiasSetting(str, Enum):
FAVOR_RECENT = "favor_recent" # 2x decay rate
BASE_DECAY = "base_decay"
@@ -19,12 +12,26 @@ class RecencyBiasSetting(str, Enum):
AUTO = "auto"
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"
# Determine whether to run search based on history and latest query
AUTO = "auto"
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
HYBRID = "hybrid"
class LLMEvaluationType(str, Enum):
AGENTIC = "agentic" # applies agentic evaluation
BASIC = "basic" # applies boolean evaluation
SKIP = "skip" # skips evaluation
UNSPECIFIED = "unspecified" # reverts to default
class QueryFlow(str, Enum):
SEARCH = "search"
QUESTION_ANSWER = "question-answer"

View File

@@ -6,13 +6,13 @@ from pydantic import validator
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.db.models import Persona
from danswer.indexing.models import BaseChunk
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
@@ -78,7 +78,7 @@ class SearchRequest(ChunkContext):
hybrid_alpha: float = HYBRID_ALPHA
# This is to forcibly skip (or run) the step, if None it uses the system defaults
skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
class Config:
arbitrary_types_allowed = True
@@ -88,11 +88,11 @@ class SearchQuery(ChunkContext):
query: str
filters: IndexFilters
recency_bias_multiplier: float
evaluation_type: LLMEvaluationType
num_hits: int = NUM_RETURNED_HITS
offset: int = 0
search_type: SearchType = SearchType.HYBRID
skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
# Only used if not skip_llm_chunk_filter
@@ -126,6 +126,7 @@ class InferenceChunk(BaseChunk):
document_id: str
source_type: DocumentSource
semantic_identifier: str
title: str | None # Separate from Semantic Identifier though often same
boost: int
recency_bias: float
score: float | None
@@ -193,16 +194,16 @@ class InferenceChunk(BaseChunk):
class InferenceChunkUncleaned(InferenceChunk):
title: str | None # Separate from Semantic Identifier though often same
metadata_suffix: str | None
def to_inference_chunk(self) -> InferenceChunk:
# Create a dict of all fields except 'title' and 'metadata_suffix'
# Create a dict of all fields except 'metadata_suffix'
# Assumes the cleaning has already been applied and just needs to translate to the right type
inference_chunk_data = {
k: v
for k, v in self.dict().items()
if k not in ["title", "metadata_suffix"]
if k
not in ["metadata_suffix"] # May be other fields to throw out in the future
}
return InferenceChunk(**inference_chunk_data)

View File

@@ -5,17 +5,19 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.models import RelevanceChunk
from danswer.configs.chat_configs import DISABLE_AGENTIC_SEARCH
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
@@ -29,11 +31,13 @@ from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.search.utils import relevant_sections_to_indices
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -83,11 +87,13 @@ class SearchPipeline:
# Reranking and LLM section selection can be run together
# If only LLM selection is on, the reranked chunks are yielded immediatly
self._reranked_sections: list[InferenceSection] | None = None
self._relevant_section_indices: list[int] | None = None
self._final_context_sections: list[InferenceSection] | None = None
self._section_relevance: list[SectionRelevancePiece] | None = None
# Generates reranked chunks and LLM selections
self._postprocessing_generator: (
Iterator[list[InferenceSection] | list[int]] | None
Iterator[list[InferenceSection] | list[SectionRelevancePiece]] | None
) = None
"""Pre-processing"""
@@ -154,6 +160,7 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +180,11 @@ class SearchPipeline:
expanded_inference_sections = []
# Full doc setting takes priority
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -195,7 +204,6 @@ class SearchPipeline:
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -240,32 +248,35 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges = [r for ranges in merged_ranges for r in ranges]
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
# Only need the document_id here, just use any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)
# list of list of inference chunks where the inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
else:
parallel_functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
)
if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk
@@ -326,44 +337,71 @@ class SearchPipeline:
return self._reranked_sections
@property
def relevant_section_indices(self) -> list[int]:
if self._relevant_section_indices is not None:
return self._relevant_section_indices
def final_context_sections(self) -> list[InferenceSection]:
if self._final_context_sections is not None:
return self._final_context_sections
self._relevant_section_indices = next(
cast(Iterator[list[int]], self._postprocessing_generator)
)
return self._relevant_section_indices
self._final_context_sections = _merge_sections(sections=self.reranked_sections)
return self._final_context_sections
@property
def relevance_summaries(self) -> dict[str, RelevanceChunk]:
if DISABLE_AGENTIC_SEARCH:
def section_relevance(self) -> list[SectionRelevancePiece] | None:
if self._section_relevance is not None:
return self._section_relevance
if (
self.search_query.evaluation_type == LLMEvaluationType.SKIP
or DISABLE_LLM_DOC_RELEVANCE
):
return None
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
raise ValueError(
"Agentic saerch operation called while DISABLE_AGENTIC_SEARCH is toggled"
"Attempted to access section relevance scores on search query with evaluation type `UNSPECIFIED`."
+ "The search query evaluation type should have been specified."
)
if len(self.reranked_sections) == 0:
logger.warning(
"No sections found in agentic search evalution. Returning empty dict."
if self.search_query.evaluation_type == LLMEvaluationType.AGENTIC:
sections = self.final_context_sections
functions = [
FunctionCall(
evaluate_inference_section,
(section, self.search_query.query, self.llm),
)
for section in sections
]
try:
results = run_functions_in_parallel(function_calls=functions)
self._section_relevance = list(results.values())
except Exception:
raise ValueError(
"An issue occured during the agentic evaluation proecss."
)
elif self.search_query.evaluation_type == LLMEvaluationType.BASIC:
if DISABLE_LLM_DOC_RELEVANCE:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],
self._postprocessing_generator,
)
)
return {}
sections = self.reranked_sections
functions = [
FunctionCall(
evaluate_inference_section, (section, self.search_query.query, self.llm)
else:
# All other cases should have been handled above
raise ValueError(
f"Unexpected evaluation type: {self.search_query.evaluation_type}"
)
for section in sections
]
results = run_functions_in_parallel(function_calls=functions)
return {
next(iter(value)): value[next(iter(value))] for value in results.values()
}
return self._section_relevance
@property
def section_relevance_list(self) -> list[bool]:
return [
True if ind in self.relevant_section_indices else False
for ind in range(len(self.reranked_sections))
]
llm_indices = relevant_sections_to_indices(
relevance_sections=self.section_relevance,
items=self.final_context_sections,
)
return [ind in llm_indices for ind in range(len(self.final_context_sections))]

View File

@@ -4,7 +4,8 @@ from typing import cast
import numpy
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
@@ -12,6 +13,10 @@ from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import (
CrossEncoderEnsembleModel,
)
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@@ -20,7 +25,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
@@ -46,10 +50,6 @@ def should_rerank(query: SearchQuery) -> bool:
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter
def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
if not chunk.title or not chunk.content:
@@ -58,8 +58,14 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()
if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
# BLURB SIZE is by token instead of char but each token is at least 1 char
# If this prefix matches the content, it's assumed the title was prepended
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
return (
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
if RETURN_SEPARATOR in chunk.content
else chunk.content
)
return chunk.content
@@ -91,7 +97,11 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]
passages = [
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks
]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
@@ -202,11 +212,17 @@ def filter_sections(
section.center_chunk.content if use_chunk else section.combined_content
for section in sections_to_filter
]
metadata_list = [section.center_chunk.metadata for section in sections_to_filter]
titles = [
section.center_chunk.semantic_identifier for section in sections_to_filter
]
llm_chunk_selection = llm_batch_eval_sections(
query=query.query,
section_contents=contents,
llm=llm,
titles=titles,
metadata_list=metadata_list,
)
return [
@@ -221,9 +237,15 @@ def search_postprocessing(
retrieved_sections: list[InferenceSection],
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceSection] | list[int]]:
) -> Iterator[list[InferenceSection] | list[SectionRelevancePiece]]:
post_processing_tasks: list[FunctionCall] = []
if not retrieved_sections:
# Avoids trying to rerank an empty list which throws an error
yield []
yield []
return
rerank_task_id = None
sections_yielded = False
if should_rerank(search_query):
@@ -247,7 +269,10 @@ def search_postprocessing(
sections_yielded = True
llm_filter_task_id = None
if should_apply_llm_based_relevance_filter(search_query):
if search_query.evaluation_type in [
LLMEvaluationType.BASIC,
LLMEvaluationType.UNSPECIFIED,
]:
post_processing_tasks.append(
FunctionCall(
filter_sections,
@@ -288,7 +313,11 @@ def search_postprocessing(
)
yield [
index
for index, section in enumerate(reranked_sections or retrieved_sections)
if section.center_chunk.unique_id in llm_selected_section_ids
SectionRelevancePiece(
document_id=section.center_chunk.document_id,
chunk_id=section.center_chunk.chunk_id,
relevant=section.center_chunk.unique_id in llm_selected_section_ids,
content="",
)
for section in (reranked_sections or retrieved_sections)
]

View File

@@ -1,26 +1,20 @@
from typing import TYPE_CHECKING
from danswer.natural_language_processing.search_nlp_models import IntentModel
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
if TYPE_CHECKING:
from transformers import AutoTokenizer # type:ignore
def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
def count_unk_tokens(text: str, tokenizer: BaseTokenizer) -> int:
"""Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
It splits up even foreign characters and unicode emojis without using UNK"""
tokenized_text = tokenizer.tokenize(text)
num_unk_tokens = len(
[token for token in tokenized_text if token == tokenizer.unk_token]
)
num_unk_tokens = len([token for token in tokenized_text if token == "[UNK]"])
logger.debug(f"Total of {num_unk_tokens} UNKNOWN tokens found")
return num_unk_tokens
@@ -74,7 +68,12 @@ def recommend_search_flow(
# UNK tokens -> suggest Keyword (still may be valid QA)
# TODO do a better job with the classifier model and retire the heuristics
if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
if (
count_unk_tokens(
query, get_tokenizer(model_name=model_name, provider_type=None)
)
> 0
):
if not keyword:
heuristic_search_type = SearchType.KEYWORD
message = "Unknown tokens in query."

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import BaseFilters
@@ -35,7 +36,6 @@ def retrieval_preprocessing(
db_session: Session,
bypass_acl: bool = False,
include_query_intent: bool = True,
disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
) -> tuple[SearchQuery, SearchType | None, QueryFlow | None]:
@@ -137,18 +137,23 @@ def retrieval_preprocessing(
access_control_list=user_acl_filters,
)
llm_chunk_filter = False
if search_request.skip_llm_chunk_filter is not None:
llm_chunk_filter = not search_request.skip_llm_chunk_filter
elif persona:
llm_chunk_filter = persona.llm_relevance_filter
llm_evaluation_type = LLMEvaluationType.BASIC
if search_request.evaluation_type is not LLMEvaluationType.UNSPECIFIED:
llm_evaluation_type = search_request.evaluation_type
if disable_llm_chunk_filter:
if llm_chunk_filter:
elif persona:
llm_evaluation_type = (
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
)
if DISABLE_LLM_DOC_RELEVANCE:
if llm_evaluation_type:
logger.info(
"LLM chunk filtering would have run but has been globally disabled"
)
llm_chunk_filter = False
llm_evaluation_type = LLMEvaluationType.SKIP
skip_rerank = search_request.skip_rerank
if skip_rerank is None:
@@ -176,7 +181,7 @@ def retrieval_preprocessing(
num_hits=limit if limit is not None else NUM_RETURNED_HITS,
offset=offset or 0,
skip_rerank=skip_rerank,
skip_llm_chunk_filter=not llm_chunk_filter,
evaluation_type=llm_evaluation_type,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,
full_doc=search_request.full_doc,

View File

@@ -11,6 +11,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@@ -20,7 +21,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger

View File

@@ -1,211 +0,0 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import requests
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
logger = setup_logger()
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def build_model_server_url(
model_server_host: str,
model_server_port: int,
) -> str:
model_server_url = f"{model_server_host}:{model_server_port}"
# use protocol if provided
if "http" in model_server_url:
return model_server_url
# otherwise default to http
return f"http://{model_server_url}"
class EmbeddingModel:
def __init__(
self,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
# The following are globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
return EmbedResponse(**response.json()).embeddings
class CrossEncoderEnsembleModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
rerank_request = RerankRequest(query=query, documents=passages)
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
class IntentModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.intent_server_endpoint = model_server_url + "/custom/intent-model"
def predict(
self,
query: str,
) -> list[float]:
intent_request = IntentRequest(query=query)
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
)
response.raise_for_status()
return IntentResponse(**response.json()).class_probs
def warm_up_encoders(
model_name: str,
normalize: bool,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at "
"https://docs.danswer.dev/quickstart"
)
# May not be the exact same tokenizer used for the indexing flow
get_default_tokenizer(model_name=model_name)(warm_up_str)
embed_model = EmbeddingModel(
model_name=model_name,
normalize=normalize,
# Not a big deal if prefix is incorrect
query_prefix=None,
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)
# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for attempt in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
except Exception:
logger.exception(
f"Failed to run test embedding, retrying in {wait_time} seconds..."
)
time.sleep(wait_time)
raise Exception("Failed to run test embedding.")

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from typing import TypeVar
from danswer.chat.models import SectionRelevancePiece
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
@@ -18,6 +19,14 @@ T = TypeVar(
SavedSearchDocWithContent,
)
TSection = TypeVar(
"TSection",
InferenceSection,
SearchDoc,
SavedSearchDoc,
SavedSearchDocWithContent,
)
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
seen_ids = set()
@@ -37,6 +46,35 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
return deduped_items, dropped_indices
def relevant_sections_to_indices(
relevance_sections: list[SectionRelevancePiece] | None, items: list[TSection]
) -> list[int]:
if not relevance_sections:
return []
relevant_set = {
(chunk.document_id, chunk.chunk_id)
for chunk in relevance_sections
if chunk.relevant
}
return [
index
for index, item in enumerate(items)
if (
(
isinstance(item, InferenceSection)
and (item.center_chunk.document_id, item.center_chunk.chunk_id)
in relevant_set
)
or (
not isinstance(item, (InferenceSection))
and (item.document_id, item.chunk_ind) in relevant_set
)
)
]
def drop_llm_indices(
llm_indices: list[int],
search_docs: Sequence[DBSearchDoc | SavedSearchDoc],

Some files were not shown because too many files have changed in this diff Show More