Compare commits

...

252 Commits

Author SHA1 Message Date
Yuhong Sun
4293543a6a k 2024-07-20 16:48:05 -07:00
Yuhong Sun
e95bfa0e0b Suffix Test (#1880) 2024-07-20 15:54:55 -07:00
Yuhong Sun
4848b5f1de Suffix Edits (#1878) 2024-07-20 13:59:14 -07:00
Yuhong Sun
7ba5c434fa Missing Comma (#1877) 2024-07-19 22:15:45 -07:00
Yuhong Sun
59bf5ba848 File Connector Metadata (#1876) 2024-07-19 20:45:18 -07:00
Weves
f66c33380c Improve widget README 2024-07-19 20:21:07 -07:00
Weves
115650ce9f Add example widget code 2024-07-19 20:14:52 -07:00
Weves
7aa3602fca Fix black 2024-07-19 18:55:09 -07:00
Weves
864c552a17 Fix UT 2024-07-19 18:55:09 -07:00
Brent Kwok
07b2ed3d8f Fix HTTP 422 error for api_inference_sample.py (#1868) 2024-07-19 18:54:43 -07:00
Yuhong Sun
38290057f2 Search Eval (#1873) 2024-07-19 16:48:58 -07:00
Weves
2344edf158 Change default login time to 7 days 2024-07-19 13:58:50 -07:00
versecafe
86d1804eb0 Add GPT-4o-Mini & fix a missing gpt-4o 2024-07-19 12:10:27 -07:00
pablodanswer
1ebae50d0c minor update 2024-07-19 10:53:28 -07:00
Weves
a9fbaa396c Stop building on every PR 2024-07-19 10:21:19 -07:00
pablodanswer
27d5f69427 update to headers (#1864) 2024-07-19 08:38:54 -07:00
pablodanswer
5d98421ae8 show "analysis" (#1863) 2024-07-18 18:18:36 -07:00
Kevin Shi
6b561b8ca9 Add config to skip zendesk article labels 2024-07-18 18:00:51 -07:00
pablodanswer
2dc7e64dd7 fix internet search icons / text + assistants tab (#1862) 2024-07-18 16:15:19 -07:00
Yuhong Sun
5230f7e22f Enforce Disable GenAI if set (#1860) 2024-07-18 13:25:55 -07:00
hagen-danswer
a595d43ae3 Fixed deleting toolcall by message 2024-07-18 12:52:28 -07:00
Yuhong Sun
ee561f42ff Cleaner Layout (#1857) 2024-07-18 11:13:16 -07:00
Yuhong Sun
f00b3d76b3 Touchup NoOp (#1856) 2024-07-18 08:44:27 -07:00
Yuhong Sun
e4984153c0 Touchups (#1855) 2024-07-17 23:47:10 -07:00
pablodanswer
87fadb07ea COMPLETE USER EXPERIENCE OVERHAUL (#1822) 2024-07-17 19:44:21 -07:00
pablodanswer
2b07c102f9 fix discourse connector rate limiting + topic fetching (#1820) 2024-07-17 14:57:40 -07:00
hagen-danswer
e93de602c3 Use SHA instead of branch and save more data (#1850) 2024-07-17 14:56:24 -07:00
hagen-danswer
1c77395503 Fixed llm_indices from document search api (#1853) 2024-07-17 14:52:49 -07:00
Victorivus
cdf6089b3e Fix bug with XML files in chat (#1804) 2024-07-17 08:09:40 -07:00
pablodanswer
d01f46af2b fix search doc bug (#1851) 2024-07-16 15:27:04 -07:00
hagen-danswer
b83f435bb0 Catch dropped eval questions and added multiprocessing (#1849) 2024-07-16 12:33:02 -07:00
hagen-danswer
25b3dacaba Separated model caching volumes (#1845) 2024-07-15 15:32:04 -07:00
hagen-danswer
a1e638a73d Improved eval logging and stability (#1843) 2024-07-15 14:58:45 -07:00
Yuhong Sun
bd1e0c5969 Add Enum File (#1842) 2024-07-15 09:13:27 -07:00
Yuhong Sun
4d295ab97d Model Server Logging (#1839) 2024-07-15 09:00:27 -07:00
Weves
6fe3eeaa48 Fix model server startup 2024-07-14 23:33:58 -07:00
Chris Weaver
078d5defbb Update Slack link in README.md 2024-07-14 16:50:48 -07:00
Weves
0d52e99bd4 Improve confluence rate limiting 2024-07-14 16:40:45 -07:00
hagen-danswer
1b864a00e4 Added support for multiple Eval Pipeline UIs (#1830) 2024-07-14 15:16:20 -07:00
Weves
dae4f6a0bd Fix latency caused by large numbers of tags 2024-07-14 14:21:07 -07:00
Yuhong Sun
f63d0ca3ad Title Truncation Logic (#1828) 2024-07-14 13:54:36 -07:00
Yuhong Sun
da31da33e7 Fix Title for docs without (#1827) 2024-07-14 13:51:11 -07:00
Yuhong Sun
56b175f597 Fix Sitemap Robo (#1826) 2024-07-14 13:29:26 -07:00
Zoltan Szabo
1b311d092e Try to find the sitemap for a given site (#1538) 2024-07-14 13:24:10 -07:00
Moshe Zada
6ee1292757 Fix semantic id for web pdfs (#1823) 2024-07-14 11:38:11 -07:00
Yuhong Sun
017af052be Global Tokenizer Fix (#1825) 2024-07-14 11:37:10 -07:00
pablodanswer
e7f81d1688 add third party embedding models (#1818) 2024-07-14 10:19:53 -07:00
Weves
b6bd818e60 Fix user groups page when a persona is deleted 2024-07-13 15:35:50 -07:00
hagen-danswer
36da2e4b27 Fixed slack groups (#1814)
* Simplified slackbot response groups and fixed the "need more help" bug

* mypy fixes

* added exceptions for the "couldn't find" passthrough arrays
2024-07-13 22:34:35 +00:00
pablodanswer
c7af6a4601 add new standard answer test endpoint (#1789) 2024-07-12 10:06:30 -07:00
Yuhong Sun
e90c66c1b6 Include Titles in Chunks (#1817) 2024-07-12 09:42:24 -07:00
hagen-danswer
8c312482c1 fixed id retrieval from zip metadata (#1813) 2024-07-11 20:38:12 -07:00
Weves
e50820e65e Remove Internet 'Connector' that mistakenly appears on the Add Connector page 2024-07-11 18:00:59 -07:00
hagen-danswer
991ee79e47 some qol improvements for search pipeline (#1809) 2024-07-11 17:42:11 -07:00
hagen-danswer
3e645a510e Fix slack error logging (#1800) 2024-07-11 08:31:48 -07:00
Yuhong Sun
08c6e821e7 Merge Sections Logic (#1801) 2024-07-10 20:14:02 -07:00
hagen-danswer
47a550221f slackbot doesn't respond without citations/quotes (#1798)
* slackbot doesn't respond without citations/quotes

fixed logical issues

fixed dict logic

* added slackbot shim for the llm source/time feature

* mypy fixes

* slackbot doesn't respond without citations/quotes

fixed logical issues

fixed dict logic

* Update handle_regular_answer.py

* added bypass_filter check

* final fixes
2024-07-11 00:18:26 +00:00
Weves
511f619212 Add content to /document-search response 2024-07-10 15:44:58 -07:00
Varun Gaur
6c51f001dc Confluence Connector to Sync Child pages only (#1629)
---------

Co-authored-by: Varun Gaur <vgaur@roku.com>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-07-10 14:17:03 -07:00
pablodanswer
09a11b5e1a Fix citations + unit tests (#1760) 2024-07-10 10:05:20 -07:00
pablodanswer
aa0f7abdac add basic table wrapping (#1791) 2024-07-09 19:14:41 -07:00
Yuhong Sun
7c8f8dba17 Break the Danswer LLM logging from LiteLLM Verbose (#1795) 2024-07-09 18:18:29 -07:00
Yuhong Sun
39982e5fdc Info propagating to allow Chunk Merging (#1794) 2024-07-09 18:15:07 -07:00
pablodanswer
5e0de111f9 fix wrapping in error hover connector (#1790) 2024-07-09 11:54:35 -07:00
pablodanswer
727d80f168 fix gpt-4o image issue (#1786) 2024-07-08 23:07:53 +00:00
rashad-danswer
146f85936b Internet Search Tool (#1666)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-06 18:01:24 -07:00
Chris Weaver
e06f8a0a4b Standard Answers (#1753)
---------

Co-authored-by: druhinsgoel <druhin@danswer.ai>
2024-07-06 16:11:11 -07:00
Yuhong Sun
f0888f2f61 Eval Script Incremental Write (#1784) 2024-07-06 15:43:40 -07:00
Yuhong Sun
d35d7ee833 Evaluation Pipeline Touchup (#1783) 2024-07-06 13:17:05 -07:00
Yuhong Sun
c5bb3fde94 Ignore Eval Files (#1782) 2024-07-06 12:15:03 -07:00
Yuhong Sun
79190030a5 New Env File for Eval (#1781) 2024-07-06 12:07:31 -07:00
Yuhong Sun
8e8f262ed3 Docker Compose Eval Pipeline Cleanup (#1780) 2024-07-06 12:04:57 -07:00
hagen-danswer
ac14369716 Added search quality testing pipeline (#1774) 2024-07-06 11:51:50 -07:00
Weves
de4d8e9a65 Fix shared chats 2024-07-04 11:41:16 -07:00
hagen-danswer
0b384c5b34 fixed salesforce url generation (#1777) 2024-07-04 10:43:21 -07:00
Weves
fa049f4f98 Add UI support for github configs 2024-07-03 17:37:59 -07:00
pablodanswer
72d6a0ef71 minor updates to assistant UI (#1771) 2024-07-03 18:28:25 +00:00
pablodanswer
ae4e643266 Update Assistants Creation UI (#1714)
* slide up "Tools"

* rework assistants page

* update layout

* reorg complete

- pending: useful header text?

* add tooltips

* alter organizational structure

* rm shadcn

* rm dependencies

* revalidate dependencies

* restore

* update component structure

* [s] format

* rm package json

* add package-lock.json [s]

* collapsible

* naming + width

* formatting

* formatting

* updated user flow

- Fix error/detail messages
- Fix tooltip delay
- Fix icons

* 1 -> 2

* naming fixes

* ran pretty

* fix build issue?

* web build issues?
2024-07-03 17:11:14 +00:00
hagen-danswer
a7da07afc0 allowed arbitrary types to handle the sqlalchemy datatype (#1758)
* allowed arbitrary types to handle the sqlalchemy datatype

* changed persona_upsert to take in ids instead of objects
2024-07-03 07:10:57 +00:00
Weves
7f1bb67e52 Pass through API base to ImageGenerationTool 2024-07-02 23:31:04 -07:00
Weves
982b1b0c49 Add litellm.set_verbose support 2024-07-02 23:22:17 -07:00
Daniel Naber
2db128fb36 Notion date filter fix (#1755)
* fix filter logic

* make comparison more readable
2024-07-02 15:39:35 -07:00
Christoph Petzold
3ebac6256f Fix "cannot access local variable" for bot direct messages (#1737)
* Update handle_message.py

* Update handle_message.py

* Update handle_message.py

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-07-02 15:36:23 -07:00
Weves
1a3ec59610 Fix build caused by bad seeding config 2024-07-01 23:41:43 -07:00
hagen-danswer
581cb827bb added settings and persona seeding options (#1742)
* added settings and persona seeding options

* updated recency_bias

* changed variable type

* another fix

* Update seeding.py

* fixed mypy

* push
2024-07-01 22:22:17 +00:00
Weves
393b3c9343 Fix misc chat bugs 2024-06-30 18:08:03 -07:00
Weves
2035e9f39c Fix docker build 2024-06-30 14:34:47 -07:00
Weves
52c3a5e9d2 Fix slackbot citation images 2024-06-30 13:42:56 -07:00
pablodanswer
3e45a41617 Bugfix/scroll (#1748) 2024-06-30 12:58:57 -07:00
Weves
415960564d Fix fast models 2024-06-29 15:19:09 -07:00
pablodanswer
ed550986a6 Feature/assistants (#1581)
* include alternate assistant

- migrate models
- migrate db

* functional alternate assistant selection

* refactor chat components for persona API

* functional assistants api

* add full functionality- assistants

* add functional assistants dropdown handler

* refactor assistants for full compatibility

- hooks
- track the live assistant for edge cases
- UI updates

* add assistant UI features

- Autotab
- Arrow selection
- Icons
- Proper @ detection
- Info Popup

prune unnecessary comments

* functional search toggling for assistants

* add functional cross-page assistants

rebase with main

* add proper interactivity for edge cases

- click outside of input / text box
- "force search" assistant consistency

* refactor alt assistant consistency

* update alembic versions

* rebased

* undo formatting changes

* additional formatting

* current processing

* merge fixes

* formatting

* colors

* 2 -> 1

* 1 -> 2

---------

Co-authored-by: “Pablo <“pablo@danswer.ai”>
2024-06-29 00:18:39 +00:00
hagen-danswer
60dd77393d Disallowed simultaneous pruning jobs (#1704)
* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* Disallowed simultaneous pruning jobs

* added rate limiting and env vars

* i hope this is how you use decorators

* nonsense

* cleaned up names, added config

* renamed other utils

* Update celery_utils.py

* reverted changes
2024-06-28 23:26:00 +00:00
hagen-danswer
3fe5313b02 renamed alembic table (#1741) 2024-06-28 22:54:19 +00:00
hagen-danswer
bd0925611a Added TTL to EE Celery tasks (#1713)
* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* default value to None

* made all deletions manual

* added fix

* Use tremor buttons now

* removed words

* Update 23957775e5f5_remove_feedback_foreignkey_constraint.py

* fixed alembic version
2024-06-28 22:13:47 +00:00
pablodanswer
de6d040349 add boto3 typing to default requirements (#1740) 2024-06-28 13:06:59 -07:00
Chris Weaver
38da3128d8 Pass headers into image generation (#1739) 2024-06-28 12:33:53 -07:00
Chris Weaver
e47da0d688 Small readme improvement (#1735) 2024-06-27 19:12:24 -07:00
rkuo-danswer
2c0e0c5f11 Merge pull request #1731 from danswer-ai/feature/merge-queue-workflows
add workflows that automate docker builds against merge group events
2024-06-27 18:25:05 -07:00
Richard Kuo (Danswer)
29d57f6354 remove obsolete comment 2024-06-27 17:39:17 -07:00
rkuo-danswer
369e607631 Fixes DAN-189 (safari bug in admin). Removed td/absolute positioning behavior which is unde… (#1718) 2024-06-27 17:15:42 -07:00
pablodanswer
f03f97307f Blob Storage (#1705)
S3 + OCI + Google Cloud Storage + R2
---------

Co-authored-by: Art Matsak <5328078+artmatsak@users.noreply.github.com>
2024-06-27 17:12:20 -07:00
Weves
145cdb69b7 Remove duplicate tool check 2024-06-27 17:05:50 -07:00
pablodanswer
9310a8edc2 Feature/scroll (#1694)
---------

Co-authored-by: “Pablo <“pablo@danswer.ai”>
2024-06-27 16:40:23 -07:00
Yuhong Sun
2140f80891 Tidy up Actions ported from EE (#1732) 2024-06-27 16:20:34 -07:00
Richard Kuo (Danswer)
52dab23295 add workflows that automate docker builds against merge group events 2024-06-27 14:31:39 -07:00
Weves
91c9b2eb42 Add more logging for num workers in simple job client 2024-06-27 11:49:08 -07:00
Weves
5764cdd469 Use FiEdit2 as the standard edit icon 2024-06-27 11:32:47 -07:00
Weves
8fea6d7f64 Fix share for insecure: 2024-06-27 11:19:04 -07:00
pablodanswer
5324b15397 Chat overflow (#1723) 2024-06-27 11:02:11 -07:00
Yuhong Sun
8be42a5f98 Touchup for Multilingual Users (#1725) 2024-06-26 22:44:06 -07:00
Weves
062dc98719 Fix search tool 2024-06-26 21:28:16 -07:00
pablodanswer
43557f738b add copy-paste images (#1722) 2024-06-26 20:12:34 -07:00
Weves
b5aa7370a2 Make seeded model default 2024-06-26 16:03:24 -07:00
Weves
4ba6e45128 Small template fix 2024-06-26 13:58:30 -07:00
pablodanswer
d6e5a98a22 Minor Update to UI (#1692)
New sidebar / chatbar color, hidable right-panel, and many more small tweaks.

---------

Co-authored-by: pablodanswer <“pablo@danswer.ai”>
2024-06-26 13:54:41 -07:00
Yuhong Sun
20c4cdbdda Catch LLM Generation Failure (#1712) 2024-06-26 10:44:22 -07:00
Yuhong Sun
0d814939ee Bugfix for Selected Doc when the message it is selected from failed (#1711) 2024-06-26 10:32:30 -07:00
Yuhong Sun
7d2b0ffcc5 Developer Env Setup (#1710) 2024-06-26 09:45:19 -07:00
hagen-danswer
8c6cd661f5 Ignore messages from Slack's official bot (#1703)
* Ignore messages from Slack's official bot

* re-added message filter
2024-06-25 20:33:36 -07:00
Weves
5d552705aa Change EE environment variable name 2024-06-25 16:59:09 -07:00
Weves
1ee8ee9e8b Prepare EE to merge with MIT 2024-06-25 15:07:56 -07:00
Chris Weaver
f0b2b57d81 Usage reports (#118)
---------

Co-authored-by: amohamdy99 <a.mohamdy99@gmail.com>
2024-06-25 15:07:56 -07:00
hagen-danswer
5c12a3e872 brought out the UsersResponse interface (#119) 2024-06-25 15:07:56 -07:00
Weves
3af81ca96b Fix seed config when left empty 2024-06-25 15:07:56 -07:00
Weves
f55e5415bb Add empty assets folder 2024-06-25 15:07:56 -07:00
Weves
3d434c2c9e Fix persona access for answer-with-quote API 2024-06-25 15:07:56 -07:00
pablodanswer
90ec156791 formatting 2024-06-25 15:07:56 -07:00
pablodanswer
8ba48e24a6 minor build fix 2024-06-25 15:07:56 -07:00
pablodanswer
e34bcbbd06 Add persistent name and logo seeding (#107) 2024-06-25 15:07:56 -07:00
pablodanswer
db319168f8 stronger wording 2024-06-25 15:07:56 -07:00
pablodanswer
010ce5395f Minor/ee optional branding (#105) 2024-06-25 15:07:56 -07:00
rashad-danswer
98a58337a7 Query history speed fix (#109) 2024-06-25 15:07:56 -07:00
Weves
733d4e666b Add support for private file connectors 2024-06-25 15:07:56 -07:00
Weves
2937fe9e7d Fix backend build 2024-06-25 15:07:56 -07:00
Weves
457527ac86 Try different runner groups for each build 2024-06-25 15:07:56 -07:00
Weves
7cc51376f2 Allow basic seeding of Danswer via env variable 2024-06-25 15:07:56 -07:00
Weves
7278d45552 Fix rebase issue 2024-06-25 15:07:56 -07:00
Yuhong Sun
1c343bbee7 Enable Dedup Flag for Doc Search Endpoint 2024-06-25 15:07:56 -07:00
Weves
bdcfb39724 Add whitelabeled name to login page 2024-06-25 15:07:56 -07:00
Weves
694d20ea8f Fix user groups issue from rebase 2024-06-25 15:07:56 -07:00
Weves
45402d0755 Add back custom logo/name to sidebar header 2024-06-25 15:07:56 -07:00
Weves
69740ba3d5 Fix rebase issue 2024-06-25 15:07:56 -07:00
Yuhong Sun
6162283beb Fix formatting issues (#93) 2024-06-25 15:07:56 -07:00
Yuhong Sun
44284f7912 Fix Rebase Issues (#92) 2024-06-25 15:07:56 -07:00
Weves
775ca5787b Move web build to a matrix build 2024-06-25 15:07:56 -07:00
Weves
c6e49a3034 Don't get duplicate docs during user group syncing 2024-06-25 15:07:56 -07:00
Weves
9c8cfd9175 Fix mypy 2024-06-25 15:07:56 -07:00
Weves
fc3ed76d12 Add pagination to user group syncing 2024-06-25 15:07:56 -07:00
Weves
a2597d5f21 Fix rebase issue with dev compose file 2024-06-25 15:07:56 -07:00
Yuhong Sun
af588461d2 Enable Encryption 2024-06-25 15:07:56 -07:00
Weves
460e61b3a7 Fix document lock acquisition for user group sync 2024-06-25 15:07:56 -07:00
Weves
c631ac0c3a Change secret name 2024-06-25 15:07:56 -07:00
Yuhong Sun
10be91a8cc Track Slack questions Autoresolved (#86) 2024-06-25 15:07:56 -07:00
Weves
eadad34a77 Fix /send-message-simple-api endpoint 2024-06-25 15:07:56 -07:00
Weves
b19d88a151 Fix rebase issue with file_store 2024-06-25 15:07:56 -07:00
Weves
e33b469915 Remove unused Chat.tsx file 2024-06-25 15:07:56 -07:00
Weves
719fc06604 Fix rebase issue with UI-based LLM selection 2024-06-25 15:07:56 -07:00
Alan Hagedorn
d7a704c0d9 Token Rate Limiting
WIP

Cleanup 🧹

Remove existing rate limiting logic

Cleanup 🧼

Undo nit

Cleanup 🧽

Move db constants (avoids circular import)

WIP

WIP

Cleanup

Lint

Resolve alembic conflict

Fix mypy

Add backfill to migration

Update comment

Make unauthenticated users still adhere to global limits

Use Depends

Remove enum from table

Update migration error handling + deletion

Address verbal feedback, cleanup urls, minor nits
2024-06-25 15:07:56 -07:00
Yuhong Sun
7a408749cf Fix Web Compile Issue (#81) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d9acd03a85 Query History Include Feedback Text (#80) 2024-06-25 15:07:56 -07:00
Yuhong Sun
af94c092e7 Reduce sync jobs batch size (#79) 2024-06-25 15:07:56 -07:00
Yuhong Sun
f55a4ef9bd Remove Nested Session (#78) 2024-06-25 15:07:56 -07:00
Yuhong Sun
6c6e33e001 Allow Empty API Names (#77) 2024-06-25 15:07:56 -07:00
Yuhong Sun
336c046e5d Better Naming for API Keys (#76) 2024-06-25 15:07:56 -07:00
Weves
9a9b89f073 Fix rebase issue with public assistants 2024-06-25 15:07:56 -07:00
Weves
89fac98534 Fix ee redirect 2024-06-25 15:07:56 -07:00
Weves
65b65518de Add back ChatBanner 2024-06-25 15:07:56 -07:00
Yuhong Sun
0c827d1e6c Permission Sync Framework (#44) 2024-06-25 15:07:56 -07:00
Weves
1984f2c1ca Add automated auth checks for ee 2024-06-25 15:07:56 -07:00
Yuhong Sun
50f006557f Add message id to simple message endpoint (#69) 2024-06-25 15:07:56 -07:00
Yuhong Sun
c00bd44bcc Add Chunk Context options for EE APIs (#68) 2024-06-25 15:07:56 -07:00
Yuhong Sun
680aca68e5 Make EE containers public changes (#67) 2024-06-25 15:07:56 -07:00
Weves
22a2f86fb9 FE build fix 2024-06-25 15:07:56 -07:00
Weves
c055dc1535 Add custom analytics script 2024-06-25 15:07:56 -07:00
Alan Hagedorn
81e9880d9d Add names to API Keys (#63) 2024-06-25 15:07:56 -07:00
Weves
3466f6d3a4 Custom banner 2024-06-25 15:07:56 -07:00
Weves
91cf45165f Small fixes + adding 'Powered by Danswer' 2024-06-25 15:07:56 -07:00
Weves
ee2a5bbf49 Add custom-styling ability via themes 2024-06-25 15:07:56 -07:00
Weves
153007c57c Whitelabeling for Logo / Name via Admin panel 2024-06-25 15:07:56 -07:00
Yuhong Sun
fa8cc10063 Allow Optional Rerank in APIs (#60) 2024-06-25 15:07:56 -07:00
Yuhong Sun
2c3ba5f021 Include User in Query Export (#59) 2024-06-25 15:07:56 -07:00
Yuhong Sun
e3ef620094 Query History Now Handles Old Messages (#58) 2024-06-25 15:07:56 -07:00
Yuhong Sun
40369e0538 Formatter (#57) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d6c5c65b51 Fix Query History (#56) 2024-06-25 15:07:56 -07:00
Yuhong Sun
7b16cb9562 Rebase search changes to EE APIs (#55) 2024-06-25 15:07:56 -07:00
Weves
ef4f06a375 Fix SAML for /manage/me 2024-06-25 15:07:56 -07:00
Chris Weaver
17cc262f5d Private personas doc sets (#52)
Private Personas and Document Sets

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-25 15:07:56 -07:00
Yuhong Sun
680482bd06 Metadata filter for document search API (#53) 2024-06-25 15:07:56 -07:00
Weves
64874d2737 Change runner for backend 2024-06-25 15:07:56 -07:00
Weves
00ade322f1 Fix small answer-with-quote bug 2024-06-25 15:07:56 -07:00
Weves
eab5d054d5 Add env variable to control hash rounds 2024-06-25 15:07:56 -07:00
Weves
a09d60d7d0 Fix user group deletion bug 2024-06-25 15:07:56 -07:00
Yuhong Sun
f17dc52b37 One Shot API No Stream (#41) 2024-06-25 15:07:56 -07:00
Yuhong Sun
c1862e961b Simple API No Longer Require Specify Prompt (#40) 2024-06-25 15:07:56 -07:00
Yuhong Sun
6b46a71cb5 Fix Empty Chat for API (#39) 2024-06-25 15:07:56 -07:00
Yuhong Sun
9ae3a4af7f Basic Chat API (#38) 2024-06-25 15:07:56 -07:00
Weves
328b96c9ff Make public/not-public selector prettier 2024-06-25 15:07:56 -07:00
Yuhong Sun
bac34a47b2 Embedding Model Swap Changes (#35) 2024-06-25 15:07:56 -07:00
Weves
15934ee268 Fix limit/offset for document-search endpoint 2024-06-25 15:07:56 -07:00
Weves
fe975c3357 Add global prefix to EE endpoints 2024-06-25 15:07:56 -07:00
Weves
8bf483904d Fix users page with API keys + add spinner on key creation 2024-06-25 15:07:56 -07:00
Yuhong Sun
db338bfddf Introduce EE only Backend APIs (#29) 2024-06-25 15:07:56 -07:00
Weves
ae02a5199a Add API key generation in the UI + allow it to be used across all endpoints 2024-06-25 15:07:56 -07:00
Yuhong Sun
4b44073d9a CVEs (#26) 2024-06-25 15:07:56 -07:00
Weves
ce36530c79 Fix viewing other users' chat histories in query history 2024-06-25 15:07:56 -07:00
Weves
39d69838c5 Make query history fetch client-side 2024-06-25 15:07:56 -07:00
Weves
e11f0f6202 Fix /chat-session-history/{chat_session_id} endpoint when auth is enabled 2024-06-25 15:07:56 -07:00
Weves
ce870ff577 Re-style user group pages 2024-06-25 15:07:56 -07:00
Weves
a52711967f Fix analytics + query history 2024-06-25 15:07:56 -07:00
Weves
67a4eb6f6f Fix frontend typing rebase issue 2024-06-25 15:07:56 -07:00
Weves
9599388db8 Fix sidebar typo 2024-06-25 15:07:56 -07:00
Weves
f82ae158ea Mark indexing jobs as ee when running ee supervisord 2024-06-25 15:07:56 -07:00
Weves
670de6c00d Add new env variable to EE supervisord 2024-06-25 15:07:56 -07:00
Yuhong Sun
56c52bddff Fix missing supervisord change from Danswer MIT (#18) 2024-06-25 15:07:56 -07:00
Chris Weaver
3984350ff9 Improvements to Query History (#17)
* Add option to download query-history as a CSV

* Add user email + more complete timestamp
2024-06-25 15:07:56 -07:00
Weves
f799d9aa11 Fix EE import 2024-06-25 15:07:56 -07:00
Yuhong Sun
529f2c8c2d Danswer EE Version Text (#12) 2024-06-25 15:07:56 -07:00
Yuhong Sun
b4683dc841 Fix Rebase Issue 2024-06-25 15:07:56 -07:00
Chris Weaver
db8ce61ff4 Fix group prefix (#11) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d016e8335e Update default SAML config location (#10) 2024-06-25 15:07:56 -07:00
Yuhong Sun
0c295d1de5 Enable EE features for no-letsencrypt deployment (#9) 2024-06-25 15:07:56 -07:00
Chris Weaver
e9f273d99a Admin Analytics/Query History dashboards (#6) 2024-06-25 15:07:56 -07:00
Weves
428f5edd21 Update ee supervisord 2024-06-25 15:07:56 -07:00
Weves
50170cc97e Move user group syncing to Celery Beat 2024-06-25 15:07:56 -07:00
Chris Weaver
7503f8f37b Add User Groups (a.k.a. RBAC) (#4) 2024-06-25 15:07:56 -07:00
Yuhong Sun
92de6acc6f Initial EE features (#3) 2024-06-25 15:07:56 -07:00
mattboret
65d5808ea7 Confluence: add pages labels indexation (#1635)
* Confluence: add pages labels indexation

* changed the default and fixed the dict building

* Update app_configs.py

* Update connector.py

---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-25 10:41:29 -07:00
Yuhong Sun
061dab7f37 Touchup (#1702) 2024-06-25 10:34:03 -07:00
hagen-danswer
e65d9e155d fixed confluence breaking on unknown filetypes (#1698) 2024-06-25 10:19:01 -07:00
hagen-danswer
50f799edf4 Merge pull request #1672 from danswer-ai/add-groups-to-slack-bot-responses
add slack groups to user response list
2024-06-24 15:10:12 -07:00
pablodanswer
c1d8f6cb66 add basic 403 support for healthcheck (#1689) 2024-06-23 11:19:46 -07:00
pablodanswer
6c71bc05ea modify script deletion name (#1690) 2024-06-23 08:29:37 -07:00
Yuhong Sun
123ec4342a Relari (#1687)
Also includes some bugfixes
2024-06-22 18:52:48 -07:00
pablodanswer
7253316b9e Add script for forced connector deletion (#1683) 2024-06-22 17:15:25 -07:00
Weves
4ae924662c Add migration for usage reports 2024-06-22 16:21:41 -07:00
Yuhong Sun
094eea2742 Discourse Edge Case (#1685) 2024-06-22 15:17:33 -07:00
pablodanswer
8178d536b4 Add functional thread modification endpoints (#1668)
If you change which LLM you are using in a given ChatSession, that choice is now persisted and restored when you reload the page or come back to the ChatSession later
2024-06-21 18:10:30 -07:00
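Below is a minimal sketch of what this per-thread persistence looks like, assuming a SQLAlchemy `chat_session` table with the `current_alternate_model` column added by migration `0568ccf46a6b` (shown in the diffs further down); everything else here (the in-memory SQLite engine, the model name) is illustrative, not Danswer's actual API.

```python
# Sketch only: a stand-in ChatSession model carrying the column that
# migration 0568ccf46a6b adds; all other names are hypothetical.
from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ChatSession(Base):
    __tablename__ = "chat_session"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    current_alternate_model: Mapped[str | None] = mapped_column(String, nullable=True)

engine = create_engine("sqlite://")  # stand-in DB for the sketch
Base.metadata.create_all(engine)

with Session(engine) as db:
    session_row = ChatSession()
    db.add(session_row)
    # Persist the per-thread LLM override so it survives page reloads.
    session_row.current_alternate_model = "gpt-4o"
    db.commit()
```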
Yuhong Sun
5cafc96cae Enable Internet Search for Deployment Options (#1684) 2024-06-21 17:38:49 -07:00
Weves
3e39a921b0 Fix image generation output 2024-06-21 15:41:55 -07:00
Weves
98b2507045 Improve persona access 2024-06-21 13:35:14 -07:00
hagen-danswer
3dfe17a54d google drive ignores shortcut filetypes now 2024-06-20 18:50:12 -07:00
hagen-danswer
b4675082b1 Update handle_message.py 2024-06-20 21:16:26 -04:00
hagen-danswer
287a706e89 combined the input fields 2024-06-20 11:48:14 -07:00
Kevin Shi
ba58208a85 Transform HTML links to markdown behind config option (#1671) 2024-06-20 10:43:15 -07:00
hagen-danswer
694e9e8679 finished first draft 2024-06-19 22:11:33 -07:00
pablodanswer
9e30ec1f1f hide popup for non admin + if search is disabled 2024-06-19 13:08:24 -07:00
pablodanswer
1b56c75527 [minor] proper assistant line length 2024-06-19 12:14:42 -07:00
Weves
b07fdbf1d1 Cleanup user management 2024-06-18 14:37:47 -07:00
hagen-danswer
54c2547d89 Add connector document pruning task (#1652) 2024-06-18 14:26:12 -07:00
Liam Norris
58b5e25c97 User Management: Invite, Deactivate, Search, & Paginate (#1631) 2024-06-18 11:28:47 -07:00
hagen-danswer
4e15ba78d5 replicated drive fix for gmail connector (#1658) 2024-06-18 08:47:06 -07:00
Yuhong Sun
c798ade127 Code for ease of eval (#1656) 2024-06-17 20:32:12 -07:00
573 changed files with 39722 additions and 7213 deletions

View File

@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Backend Image Docker Build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: false
tags: |
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=v0.0.1

View File

@@ -5,9 +5,14 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build-and-push:
runs-on: ubuntu-latest
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
@@ -30,8 +35,8 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-backend:${{ github.ref_name }}
danswer/danswer-backend:latest
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
@@ -39,6 +44,6 @@ jobs:
uses: aquasecurity/trivy-action@master
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -10,7 +10,7 @@ env:
jobs:
build:
runs-on:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
@@ -34,8 +34,8 @@ jobs:
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=danswer/danswer-web-server:${{ github.ref_name }}
type=raw,value=danswer/danswer-web-server:latest
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: false
build-args: |
DANSWER_VERSION=v0.0.1
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

.gitignore vendored
View File

@@ -5,3 +5,5 @@
.idea
/deployment/data/nginx/app.conf
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml

View File

@@ -4,17 +4,20 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# This passes top N results to LLM an additional time for reranking prior to answer generation, quite token heavy so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_ALL_MODEL_INTERACTIONS=True
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
@@ -22,10 +25,6 @@ OAUTH_CLIENT_SECRET=<REPLACE THIS>
REQUIRE_EMAIL_VERIFICATION=False
# Toggles on/off the EE Features
NEXT_PUBLIC_ENABLE_PAID_EE_FEATURES=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
@@ -41,3 +40,13 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# Python stuff
PYTHONPATH=./backend
PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@@ -17,6 +17,7 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -28,6 +29,7 @@
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -45,8 +47,9 @@
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_ALL_MODEL_INTERACTIONS": "True",
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
@@ -63,6 +66,7 @@
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MINI_CHUNK": "false",
"LOG_LEVEL": "DEBUG",
@@ -77,7 +81,9 @@
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -100,6 +106,24 @@
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Pytest",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a specific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
]
}
}

View File

@@ -72,6 +72,10 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```
Note that this virtual environment MUST NOT be set up WITHIN the danswer directory.
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate

View File

@@ -1,6 +1,10 @@
MIT License
Copyright (c) 2023-present DanswerAI, Inc.
Copyright (c) 2023 Yuhong Sun, Chris Weaver
Portions of this software are licensed as follows:
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -105,5 +105,25 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...
## 📚 Editions
There are two editions of Danswer:
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources
* Usage analytics and query history accessible to admins
* Whitelabeling
* API key authentication
* Encryption of secrets
* And many more! Check out [our website](https://www.danswer.ai/) for the latest.
To try the Danswer Enterprise Edition:
1. Check out our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

backend/.gitignore vendored
View File

@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*

View File

@@ -1,15 +1,17 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
more details, visit https://github.com/danswer-ai/danswer."
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
contains code for both the Community and Enterprise editions of Danswer. If you do not \
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
Edition features outside of personal development or testing purposes. Please reach out to \
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -17,18 +19,32 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# zip for Vespa step further down
# ca-certificates for HTTPS
RUN apt-get update && \
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 && \
apt-get install -y \
cmake \
curl \
zip \
ca-certificates \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && playwright install-deps chromium && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
@@ -36,11 +52,20 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
libldap-2.5-0 libldap-2.5-0 && \
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
@@ -53,12 +78,24 @@ nltk.download('punkt', quiet=True);"
# Set up application files
WORKDIR /app
# Enterprise Version Files
COPY ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
# Default command which does nothing

View File

@@ -0,0 +1,31 @@
"""Add thread specific model selection
Revision ID: 0568ccf46a6b
Revises: e209dc5a8156
Create Date: 2024-06-19 14:25:36.376046
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_session", "current_alternate_model")
# ### end Alembic commands ###

View File

@@ -0,0 +1,32 @@
"""add search doc relevance details
Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_doc",
sa.Column("is_relevant", sa.Boolean(), nullable=True),
)
op.add_column(
"search_doc",
sa.Column("relevance_explanation", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("search_doc", "relevance_explanation")
op.drop_column("search_doc", "is_relevant")

View File

@@ -0,0 +1,86 @@
"""remove-feedback-foreignkey-constraint
Revision ID: 23957775e5f5
Revises: bc9771dccadf
Create Date: 2024-06-27 16:04:51.480437
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=True,
)
def downgrade() -> None:
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
)
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=False,
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval",
"chat_message",
["chat_message_id"],
["id"],
)

View File

@@ -0,0 +1,38 @@
"""add alternate assistant to chat message
Revision ID: 3a7802814195
Revises: 23957775e5f5
Create Date: 2024-06-05 11:18:49.966333
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_persona",
"chat_message",
"persona",
["alternate_assistant_id"],
["id"],
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")

View File

@@ -0,0 +1,23 @@
"""added is_internet to DBDoc
Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"
def upgrade() -> None:
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("tool", "display_name")
op.drop_column("search_doc", "is_internet")

View File

@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "48d14957fe80"
down_revision = "b85f02ec1308"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@@ -0,0 +1,35 @@
"""added slack_auto_filter
Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015
"""
from alembic import op
import sqlalchemy as sa
revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"slack_bot_config",
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
)
op.alter_column(
"slack_bot_config",
"enable_auto_filters",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
def downgrade() -> None:
op.drop_column("slack_bot_config", "enable_auto_filters")

View File

@@ -0,0 +1,23 @@
"""backfill is_internet data to False
Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
def downgrade() -> None:
pass

View File

@@ -0,0 +1,51 @@
"""create usage reports table
Revision ID: bc9771dccadf
Revises: 0568ccf46a6b
Create Date: 2024-06-18 10:04:26.800282
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bc9771dccadf"
down_revision = "0568ccf46a6b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"usage_reports",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("report_name", sa.String(), nullable=False),
sa.Column(
"requestor_user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["report_name"],
["file_store.file_name"],
),
sa.ForeignKeyConstraint(
["requestor_user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("usage_reports")

View File

@@ -0,0 +1,75 @@
"""Add standard_answer tables
Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"standard_answer",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("keyword", sa.String(), nullable=False),
sa.Column("answer", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("keyword"),
)
op.create_table(
"standard_answer_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"standard_answer__standard_answer_category",
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
)
op.create_table(
"slack_bot_config__standard_answer_category",
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["slack_bot_config_id"],
["slack_bot_config.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
)
op.add_column(
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("chat_session", "slack_thread_id")
op.drop_table("slack_bot_config__standard_answer_category")
op.drop_table("standard_answer__standard_answer_category")
op.drop_table("standard_answer_category")
op.drop_table("standard_answer")

View File

@@ -0,0 +1,45 @@
"""combined slack id fields
Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
channel_config,
'{respond_member_group_list}',
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
) - 'respond_team_member_list' - 'respond_slack_group_list'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
jsonb_set(
channel_config - 'respond_member_group_list',
'{respond_team_member_list}',
'[]'::jsonb
),
'{respond_slack_group_list}',
'[]'::jsonb
)
"""
)
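To make the `jsonb_set` upgrade above concrete, here is a small Python illustration of the transformation it applies to each `channel_config` row; the values are hypothetical.

```python
# Hypothetical channel_config value before the upgrade:
before = {
    "channel_names": ["#support"],
    "respond_team_member_list": ["alice@example.com"],
    "respond_slack_group_list": ["oncall-group"],
}

# The upgrade concatenates the two old arrays into respond_member_group_list
# and drops the old keys (the SQL's jsonb_set plus the two `-` operators).
after = {
    key: value
    for key, value in before.items()
    if key not in ("respond_team_member_list", "respond_slack_group_list")
}
after["respond_member_group_list"] = (
    before.get("respond_team_member_list", [])
    + before.get("respond_slack_group_list", [])
)

assert after == {
    "channel_names": ["#support"],
    "respond_member_group_list": ["alice@example.com", "oncall-group"],
}
```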

View File

@@ -0,0 +1,22 @@
"""added-prune-frequency
Revision ID: e209dc5a8156
Revises: 48d14957fe80
Create Date: 2024-06-16 16:02:35.273231
"""
from alembic import op
import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
def downgrade() -> None:
op.drop_column("connector", "prune_freq")

backend/assets/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -1,6 +1,7 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
@@ -19,7 +20,7 @@ def _get_access_for_documents(
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, is_public)
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
}
@@ -38,12 +39,6 @@ def get_access_for_documents(
) # type: ignore
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@@ -1,20 +1,30 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
def to_acl(self) -> list[str]:
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)
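A minimal usage sketch of the updated `DocumentAccess`, using the import path shown at the top of this diff; the UUID and group name are made-up example values.

```python
from uuid import UUID

from danswer.access.models import DocumentAccess  # import path from the diff above

access = DocumentAccess.build(
    user_ids=[UUID("12345678-1234-5678-1234-567812345678"), None],  # None entries are dropped
    user_groups=["engineering"],
    is_public=False,
)
print(access.to_acl())
# e.g. ['user_id:12345678-1234-5678-1234-567812345678', 'group:engineering']
```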

View File

@@ -0,0 +1,10 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"

View File

@@ -0,0 +1,21 @@
from typing import cast
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
USER_STORE_KEY = "INVITED_USERS"
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
return cast(list, store.load(USER_STORE_KEY))
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -9,6 +9,12 @@ class UserRole(str, Enum):
ADMIN = "admin"
class UserStatus(str, Enum):
LIVE = "live"
INVITED = "invited"
DEACTIVATED = "deactivated"
class UserRead(schemas.BaseUser[uuid.UUID]):
role: UserRole

View File

@@ -1,4 +1,3 @@
import os
import smtplib
import uuid
from collections.abc import AsyncGenerator
@@ -27,6 +26,7 @@ from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import AUTH_TYPE
@@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
@@ -54,14 +55,13 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
@@ -92,20 +92,8 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
if os.path.exists(USER_WHITELIST_FILE):
with open(USER_WHITELIST_FILE, "r") as file:
_user_whitelist = [line.strip() for line in file]
else:
_user_whitelist = []
return _user_whitelist
def verify_email_in_whitelist(email: str) -> None:
whitelist = get_user_whitelist()
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
@@ -163,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC

View File

@@ -4,13 +4,22 @@ from typing import cast
from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
@@ -22,8 +31,6 @@ from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
logger.exception(
f"Failed to run pruning for connector id {connector_id} due to {e}"
)
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
@@ -177,32 +252,48 @@ def sync_document_set_task(document_set_id: int) -> None:
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(f"Document set {document_set.id} syncing now!")
if should_sync_doc_set(document_set, db_session):
logger.info(f"Syncing the {document_set.name} document set")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
@@ -212,3 +303,11 @@ celery_app.conf.beat_schedule = {
"schedule": timedelta(seconds=5),
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
},
}
)

View File

@@ -0,0 +1,9 @@
"""Entry point for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
celery_app = fetch_versioned_implementation(
"danswer.background.celery.celery_app", "celery_app"
)

View File

@@ -1,8 +1,32 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.engine import get_db_current_time
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_deletion_status(
@@ -21,3 +45,94 @@ def get_deletion_status(
credential_id=credential_id,
status=task_state.status,
)
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now!")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
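To make the first-run rule concrete, a small worked example of the timing math (prune_freq is in seconds and is compared against DB time; the dates here are illustrative):

from datetime import datetime, timezone

prune_freq = 60 * 60 * 24  # matches DEFAULT_PRUNING_FREQ: once a day
time_created = datetime(2024, 7, 1, tzinfo=timezone.utc)
current_db_time = datetime(2024, 7, 3, tzinfo=timezone.utc)

# With no previous pruning task, prune once the connector is older than prune_freq
time_since_initialization = current_db_time - time_created
print(time_since_initialization.total_seconds() >= prune_freq)  # True: 2 days >= 1 day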
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If IdConnector hasn't been implemented for the given connector, just pull
all docs using load_from_state (or poll_source) and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
doc_batch_generator = None
if isinstance(runnable_connector, IdConnector):
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
elif isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
if doc_batch_generator:
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
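For sources that can enumerate IDs cheaply, implementing IdConnector short-circuits the batch path above. A minimal sketch — retrieve_all_source_ids comes from the isinstance branch above, the load_credentials signature is assumed to match the other connectors in this change, and the class itself is hypothetical:

from typing import Any

from danswer.connectors.interfaces import IdConnector

class StaticIdConnector(IdConnector):
    """Hypothetical connector that already knows its full set of document IDs."""

    def __init__(self, ids: set[str]) -> None:
        self._ids = ids

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        return None  # no credentials needed for this toy example

    def retrieve_all_source_ids(self) -> set[str]:
        # Everything currently present at the source; pruning deletes local IDs not in here
        return set(self._ids)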

View File

@@ -41,7 +41,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def _delete_connector_credential_pair_batch(
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
@@ -169,7 +169,7 @@ def delete_connector_credential_pair(
if not documents:
break
_delete_connector_credential_pair_batch(
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,

View File

@@ -105,7 +105,9 @@ class SimpleJobClient:
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug("No available workers to run job")
logger.debug(
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
)
return None
job_id = self.job_id_counter

View File

@@ -6,11 +6,7 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import (
_delete_connector_credential_pair_batch,
)
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -21,8 +17,6 @@ from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -37,6 +31,7 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -46,7 +41,7 @@ def _get_document_generator(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> tuple[GenerateDocumentsOutput, bool]:
) -> GenerateDocumentsOutput:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -57,16 +52,13 @@ def _get_document_generator(
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
runnable_connector = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
attempt.credential,
db_session,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
@@ -75,7 +67,7 @@ def _get_document_generator(
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
is_listing_complete = True
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
@@ -88,13 +80,12 @@ def _get_document_generator(
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
is_listing_complete = False
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, is_listing_complete
return doc_batch_generator
def _run_indexing(
@@ -107,7 +98,6 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@@ -125,6 +115,8 @@ def _run_indexing(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -166,7 +158,7 @@ def _run_indexing(
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
doc_batch_generator, is_listing_complete = _get_document_generator(
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
@@ -224,39 +216,6 @@ def _run_indexing(
docs_removed_from_index=0,
)
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
# clean up all documents from the index that have not been returned from the connector
all_indexed_document_ids = {
d.id
for d in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
}
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids
)
logger.debug(
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
)
# delete docs from cc-pair and receive the number of completely deleted docs in return
_delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=db_connector.id,
credential_id=db_credential.id,
document_index=document_index,
)
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=len(doc_ids_to_remove),
)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
@@ -329,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
@@ -346,11 +306,14 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(index_attempt_id: int) -> None:
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)

View File

@@ -22,6 +22,15 @@ def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
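The two naming forms matter for the PREVENT_SIMULTANEOUS_PRUNING check in celery_utils above: the id-less form is used to look up the latest pruning task across all cc pairs. For example:

name_cc_prune_task(connector_id=1, credential_id=2)  # "prune_connector_credential_pair_1_2"
name_cc_prune_task()                                 # "prune_connector_credential_pair"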
T = TypeVar("T", bound=Callable)

View File

@@ -35,6 +35,8 @@ from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
@@ -307,10 +309,18 @@ def kickoff_indexing_jobs(
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint, attempt.id, pure=False
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
else:
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -333,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
# So that first-time users aren't surprised by the really slow speed of the
# first batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -398,6 +410,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
set_is_ee_based_on_env_variable()
logger.info("Starting Indexing Loop")
update_loop()

View File

@@ -1,5 +1,4 @@
import re
from collections.abc import Sequence
from typing import cast
from sqlalchemy.orm import Session
@@ -9,42 +8,30 @@ from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
document_id=inference_section.center_chunk.document_id,
# This uses the combined content of all the chunks of the section
# In default settings, this is the same as just the content of the base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
content=inference_section.combined_content,
blurb=inference_section.center_chunk.blurb,
semantic_identifier=inference_section.center_chunk.semantic_identifier,
source_type=inference_section.center_chunk.source_type,
metadata=inference_section.center_chunk.metadata,
updated_at=inference_section.center_chunk.updated_at,
link=inference_section.center_chunk.source_links[0]
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
)
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
def create_chat_chain(
chat_session_id: int,
db_session: Session,

View File

@@ -1,5 +1,3 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
@@ -50,7 +48,7 @@ def load_personas_from_yaml(
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
@@ -58,22 +56,24 @@ def load_personas_from_yaml(
# Assume that if the user hasn't set any document sets for the persona, they may want
# to attach document sets to the persona manually later; therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
prompt_set_names = persona["prompts"]
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
prompts = [
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if not prompts:
prompts = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
upsert_persona(
@@ -91,8 +91,8 @@ def load_personas_from_yaml(
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
default_persona=True,
is_public=True,
db_session=db_session,

View File

@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
# TODO: make this document-level. Also a slight misnomer here, as this is currently
# done at the section level rather than the chunk level
relevant: bool | None = None
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
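A sketch of what a streamed relevance summary could look like, keyed by document (the IDs and contents here are illustrative; per the TODO above, the keys currently correspond to sections):

from danswer.chat.models import LLMRelevanceSummaryResponse, RelevanceChunk

summary = LLMRelevanceSummaryResponse(
    relevance_summaries={
        "doc-123": RelevanceChunk(relevant=True, content="The section that matched..."),
        "doc-456": RelevanceChunk(relevant=False, content=None),
    }
)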
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer

View File

@@ -10,14 +10,15 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -34,6 +35,7 @@ from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -46,10 +48,15 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
@@ -64,6 +71,14 @@ from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -141,6 +156,37 @@ def _handle_search_tool_response_summary(
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.HYBRID,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
@@ -168,7 +214,7 @@ def _check_should_force_search(
args = {"query": new_msg_req.message}
return ForceUseTool(
tool_name=SearchTool.NAME,
tool_name=SearchTool._NAME,
args=args,
)
return None
@@ -223,7 +269,15 @@ def stream_chat_message_objects(
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
persona = chat_session.persona
alternate_assistant_id = new_msg_req.alternate_assistant_id
# use the alternate persona if an alternate assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id, user=user, db_session=db_session
)
else:
persona = chat_session.persona
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
@@ -235,7 +289,7 @@ def stream_chat_message_objects(
)
try:
llm = get_llm_for_persona(
llm, fast_llm = get_llms_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
@@ -328,7 +382,7 @@ def stream_chat_message_objects(
)
selected_db_search_docs = None
selected_llm_docs: list[LlmDoc] | None = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
@@ -338,8 +392,8 @@ def stream_chat_message_objects(
)
# Generates full documents currently
# May extend to include chunk ranges
selected_llm_docs = inference_documents_from_ids(
# May extend to use sections instead in the future
selected_sections = inference_sections_from_ids(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
@@ -380,6 +434,7 @@ def stream_chat_message_objects(
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
@@ -389,11 +444,15 @@ def stream_chat_message_objects(
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
# find out what tools to use
@@ -411,21 +470,22 @@ def stream_chat_message_objects(
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
dalle_key = llm.config.api_key
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
@@ -442,9 +502,30 @@ def stream_chat_message_objects(
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(api_key=dalle_key)
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
@@ -481,10 +562,14 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
llm=(
llm
or get_llm_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
additional_headers=litellm_additional_headers,
)
)
),
message_history=[
@@ -548,6 +633,15 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
@@ -589,7 +683,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name()] = tool_id
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,

View File

@@ -4,6 +4,7 @@ import urllib.parse
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentIndexType
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
# App Configs
@@ -45,13 +46,14 @@ DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET")
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
# Turn off mask if admin users should see full credentials for data connectors.
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@@ -160,6 +162,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_S
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS")
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
HtmlBasedConnectorTransformLinksStrategy.STRIP,
)
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
@@ -178,6 +185,12 @@ CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -199,6 +212,22 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
# comma delimited list of zendesk article labels to skip indexing for
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")
#####
# Indexing Configs
@@ -219,19 +248,17 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores an additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer-grained chunking for more detail retention
# Slightly larger since the sentence-aware split is a max cutoff, so most mini-chunks will be under MINI_CHUNK_SIZE
# tokens. But it needs to be at least as big as 1/4th the chunk size to avoid a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
# If set to true, documents that "no longer exist" will not be cleaned up when running Load connectors
DISABLE_DOCUMENT_CLEANUP = (
os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true"
)
#####
@@ -246,10 +273,14 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
# Sets LiteLLM to verbose logging
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# Logs Danswer only model interactions like prompts, responses, messages etc.
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -268,3 +299,15 @@ TOKEN_BUDGET_GLOBALLY_ENABLED = (
CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
)
#####
# Enterprise Edition Configs
#####
# NOTE: this should only be enabled if you have purchased an enterprise license.
# if you're interested in an enterprise license, please reach out to us at
# founders@danswer.ai OR message Chris Weaver or Yuhong Sun in the Danswer
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

View File

@@ -5,7 +5,10 @@ PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@@ -25,9 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# For the highest-matching base-size chunk, how many chunks above and below to pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If the user has changed the default model, it would most likely be to a multilingual
# model; the stopwords are NLTK English stopwords, so in that case we would not want to drop the keywords
@@ -64,9 +66,31 @@ TITLE_CONTENT_RATIO = max(
# A list of languages passed to the LLM to rephrase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
LANGUAGE_HINT = "\n" + (
os.environ.get("LANGUAGE_HINT")
or "IMPORTANT: Respond in the same language as my query!"
)
LANGUAGE_CHAT_NAMING_HINT = (
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None

View File

@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -41,17 +42,16 @@ DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
TOKEN_BUDGET = "token_budget"
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
ENABLE_TOKEN_BUDGET = "enable_token_budget"
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
INDEX_SEPARATOR = "==="
# For File Connector Metadata override file
DANSWER_METADATA_FILENAME = ".danswer_metadata.json"
# Messages
DISABLED_GEN_AI_MSG = (
@@ -102,6 +102,21 @@ class DocumentSource(str, Enum):
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"
class BlobType(str, Enum):
R2 = "r2"
S3 = "s3"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
# Special case, for internet search
NOT_APPLICABLE = "not_applicable"
class DocumentIndexType(str, Enum):
@@ -117,6 +132,11 @@ class AuthType(str, Enum):
SAML = "saml"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
class SearchFeedbackType(str, Enum):
ENDORSE = "endorse" # boost this document for all future queries
REJECT = "reject" # down-boost this document for all future queries
@@ -142,4 +162,5 @@ class FileOrigin(str, Enum):
CHAT_UPLOAD = "chat_upload"
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"

View File

@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Auto detect query options like time cutoff or heavily favor recently updated docs
DISABLE_DANSWER_BOT_FILTER_DETECT = (
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually

View File

@@ -39,8 +39,8 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, the only way to normalize is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0

View File

@@ -0,0 +1,277 @@
import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from typing import Any
from typing import Optional
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
class BlobStorageConnector(LoadConnector, PollConnector):
def __init__(
self,
bucket_type: str,
bucket_name: str,
prefix: str = "",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.bucket_type: BlobType = BlobType(bucket_type)
self.bucket_name = bucket_name
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
self.batch_size = batch_size
self.s3_client: Optional[S3Client] = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Checks for boto3 credentials based on the bucket type.
(1) R2: Access Key ID, Secret Access Key, Account ID
(2) S3: AWS Access Key ID, AWS Secret Access Key
(3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
(4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key
For each bucket type, the method initializes the appropriate S3 client:
- R2: Uses Cloudflare R2 endpoint with S3v4 signature
- S3: Creates a standard boto3 S3 client
- GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
- OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint
Raises ConnectorMissingCredentialError if required credentials are missing.
Raises ValueError for unsupported bucket types.
"""
logger.info(
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
)
if self.bucket_type == BlobType.R2:
if not all(
credentials.get(key)
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
):
raise ConnectorMissingCredentialError("Cloudflare R2")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
aws_access_key_id=credentials["r2_access_key_id"],
aws_secret_access_key=credentials["r2_secret_access_key"],
region_name="auto",
config=Config(signature_version="s3v4"),
)
elif self.bucket_type == BlobType.S3:
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
self.s3_client = session.client("s3")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
self.s3_client = boto3.client(
"s3",
endpoint_url="https://storage.googleapis.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name="auto",
)
elif self.bucket_type == BlobType.OCI_STORAGE:
if not all(
credentials.get(key)
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name=credentials["region"],
)
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
return None
def _download_object(self, key: str) -> bytes:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
return object["Body"].read()
# NOTE: Left in as it may be useful for one-off access to documents and sharing across orgs.
# def _get_presigned_url(self, key: str) -> str:
# if self.s3_client is None:
# raise ConnectorMissingCredentialError("Blog storage")
# url = self.s3_client.generate_presigned_url(
# "get_object",
# Params={"Bucket": self.bucket_name, "Key": key},
# ExpiresIn=self.presign_length,
# )
# return url
def _get_blob_link(self, key: str) -> str:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
if self.bucket_type == BlobType.R2:
account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.S3:
region = self.s3_client.meta.region_name
return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.OCI_STORAGE:
namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
region = self.s3_client.meta.region_name
return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
def _yield_blob_objects(
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
for page in pages:
if "Contents" not in page:
continue
for obj in page["Contents"]:
if obj["Key"].endswith("/"):
continue
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
if not start <= last_modified <= end:
continue
downloaded_file = self._download_object(obj["Key"])
link = self._get_blob_link(obj["Key"])
name = os.path.basename(obj["Key"])
try:
text = extract_file_text(
name,
BytesIO(downloaded_file),
break_on_unprocessable=False,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception as e:
logger.exception(
f"Error extracting text from object {obj['Key']}: {e}"
)
if batch:
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
logger.info("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
for batch in self._yield_blob_objects(start_datetime, end_datetime):
yield batch
return None
if __name__ == "__main__":
credentials_dict = {
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
# Initialize the connector
connector = BlobStorageConnector(
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
bucket_name=os.environ.get("BUCKET_NAME") or "test",
prefix="",
)
try:
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print("First batch of documents:")
for doc in document_batch:
print(f"Document ID: {doc.id}")
print(f"Semantic Identifier: {doc.semantic_identifier}")
print(f"Source: {doc.source}")
print(f"Updated At: {doc.doc_updated_at}")
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
print(f" - Text: {section.text[:100]}...")
print("---")
break
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")

View File

@@ -15,6 +15,7 @@ from requests import HTTPError
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -36,16 +37,18 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
# Potential Improvements
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
wiki_base is https://danswer.atlassian.net/wiki
space is 1234abcd
page_id is 5678efgh
"""
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -54,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
+ parsed_url.netloc
+ parsed_url.path.split("/spaces")[0]
)
space = parsed_url.path.split("/")[3]
return wiki_base, space
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
wiki_base is https://danswer.ai/confluence
space is 1234abcd
page_id is 5678efgh
"""
# /display/ is always right before the space and at the end of the base url
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -75,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st
+ parsed_url.path.split(DISPLAY)[0]
)
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
@@ -86,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
try:
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
wiki_url
)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
except Exception as e:
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
return wiki_base, space, page_id, is_confluence_cloud
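For a cloud URL, the parser splits the path on "/spaces" to get the base, then takes positional path segments for the space and the optional page id. A self-contained sketch of the cloud branch, exercised on the docstring's sample URL:

from urllib.parse import urlparse

def parse_cloud_wiki_url(wiki_url: str) -> tuple[str, str, str]:
    parsed = urlparse(wiki_url)
    wiki_base = parsed.scheme + "://" + parsed.netloc + parsed.path.split("/spaces")[0]
    path_parts = parsed.path.split("/")
    space = path_parts[3]
    page_id = path_parts[5] if len(path_parts) > 5 else ""
    return wiki_base, space, page_id

assert parse_cloud_wiki_url(
    "https://danswer.atlassian.net/wiki/spaces/ASAM/pages/5678efgh/overview"
) == ("https://danswer.atlassian.net/wiki", "ASAM", "5678efgh")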
@lru_cache()
@@ -195,10 +212,135 @@ def _comment_dfs(
return comments_str
class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: Confluence,
index_origin: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1  # the batch_size argument is currently unused; fetches are pinned to 1
self.confluence_client = confluence_client
self.index_origin = index_origin
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def _fetch_origin_page(
self,
) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
)
return {}
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
return child_pages
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
return child_pages
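recurse_children_pages is an iterative breadth-first walk: it drains one full level of children, then repeatedly expands each page of the current level into the next. The same traversal on an in-memory child map, as a sketch with no Confluence calls:

def bfs_pages(children: dict[str, list[str]], root: str) -> list[str]:
    collected: list[str] = []
    level = children.get(root, [])
    while level:
        collected.extend(level)
        next_level: list[str] = []
        for page in level:
            next_level.extend(children.get(page, []))
        level = next_level
    return collected

tree = {"root": ["a", "b"], "a": ["a1"], "b": []}
assert bfs_pages(tree, "root") == ["a", "b", "a1"]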
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_origin: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -209,11 +351,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_origin = index_origin
(
self.wiki_base,
self.space,
self.page_id,
self.is_cloud,
) = extract_confluence_keys_from_url(wiki_page_url)
self.space_level_scan = False
self.confluence_client: Confluence | None = None
if self.page_id is None or self.page_id == "":
self.space_level_scan = True
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
@@ -231,8 +389,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self,
confluence_client: Confluence,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
@@ -241,9 +399,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind,
limit=batch_size,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
except Exception:
@@ -262,9 +422,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind + i,
limit=1,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
)
@@ -285,17 +447,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return view_pages
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_origin=self.index_origin,
)
return self.recursive_indexer.get_pages(start_ind, batch_size)
pages: list[dict[str, Any]] = []
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception as e:
if not self.continue_on_failure:
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
for i in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
@@ -307,6 +493,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)
try:
comment_pages = cast(
Collection[dict[str, Any]],
@@ -355,7 +542,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_id, start=0, limit=500
)
for attachment in attachments_container["results"]:
if attachment["metadata"]["mediaType"] in ["image/jpeg", "image/png"]:
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
continue
if attachment["title"] not in files_in_used:
@@ -366,7 +560,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if response.status_code == 200:
extract = extract_file_text(
attachment["title"], io.BytesIO(response.content)
attachment["title"], io.BytesIO(response.content), False
)
files_attachment_content.append(extract)
@@ -386,8 +580,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified_str = page["version"]["when"]
author = cast(str | None, page["version"].get("by", {}).get("email"))
@@ -403,15 +597,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter is None or time_filter(last_modified):
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
# check disallowed labels
if self.labels_to_skip:
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -432,6 +629,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": self.space
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
doc_batch.append(
Document(
@@ -440,12 +642,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
)
)
return doc_batch, len(batch)

View File

@@ -1,10 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar
from requests import HTTPError
from retry import retry
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(10):
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
time.sleep(retry_after)
else:
logger.warning(
"Rate limit hit. Retrying with exponential backoff..."
)
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
else:
# re-raise, let caller handle
raise
return cast(F, wrapped_call)
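When no usable Retry-After header is present, the wrapper falls back to capped exponential backoff. The delay schedule implied by the constants above, as a quick check:

starting_delay, backoff, max_delay = 5, 2, 600
delays = [min(starting_delay * (backoff**attempt), max_delay) for attempt in range(10)]
print(delays)  # [5, 10, 20, 40, 80, 160, 320, 600, 600, 600]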

View File

@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.text_processing import is_valid_email
@@ -57,3 +58,7 @@ def process_in_batches(
) -> Iterator[list[U]]:
for i in range(0, len(objects), batch_size):
yield [process_function(obj) for obj in objects[i : i + batch_size]]
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]

View File

@@ -11,6 +11,9 @@ from requests import Response
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
@@ -58,63 +61,36 @@ class DiscourseConnector(PollConnector):
self.category_id_map: dict[int, str] = {}
self.batch_size = batch_size
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")
return discourse_request(endpoint, self.permissions, params)
def _get_categories_map(
self,
) -> None:
assert self.permissions is not None
categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
response = self._make_request(
endpoint=categories_endpoint,
params={"include_subcategories": True},
)
categories = response.json()["category_list"]["categories"]
self.category_id_map = {
category["id"]: category["name"]
for category in categories
if not self.categories or category["name"].lower() in self.categories
cat["id"]: cat["name"]
for cat in categories
if not self.categories or cat["name"].lower() in self.categories
}
self.active_categories = set(self.category_id_map)
def _get_doc_from_topic(self, topic_id: int) -> Document:
assert self.permissions is not None
topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
response = self._make_request(endpoint=topic_endpoint)
topic = response.json()
topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")
@@ -138,10 +114,16 @@ class DiscourseConnector(PollConnector):
sections.append(
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
metadata: dict[str, str | list[str]] = (
{
"category": category_name,
}
if category_name
else {}
)
if topic.get("tags"):
metadata["tags"] = topic["tags"]
@@ -157,26 +139,78 @@ class DiscourseConnector(PollConnector):
)
return doc
def _get_latest_topics(
self, start: datetime | None, end: datetime | None, page: int
) -> list[int]:
assert self.permissions is not None
topic_ids = []
if not self.categories:
latest_endpoint = urllib.parse.urljoin(
self.base_url, f"latest.json?page={page}"
)
response = self._make_request(endpoint=latest_endpoint)
topics = response.json()["topic_list"]["topics"]
else:
topics = []
empty_categories = []
for category_id in self.category_id_map.keys():
category_endpoint = urllib.parse.urljoin(
self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
)
response = self._make_request(endpoint=category_endpoint)
new_topics = response.json()["topic_list"]["topics"]
if len(new_topics) == 0:
empty_categories.append(category_id)
topics.extend(new_topics)
for empty_category in empty_categories:
self.category_id_map.pop(empty_category)
for topic in topics:
last_time = topic.get("last_posted_at")
if not last_time:
continue
last_time_dt = time_str_to_utc(last_time)
if (start and start > last_time_dt) or (end and end < last_time_dt):
continue
topic_ids.append(topic["id"])
if len(topic_ids) >= self.batch_size:
break
return topic_ids
def _yield_discourse_documents(
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
page = 1
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
page += 1
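The rewritten generator drains topics page by page until a page comes back empty, relying on the walrus operator to terminate the loop. The same drain pattern against a stubbed topic source, as a sketch:

def fake_topic_ids(page: int) -> list[int]:
    pages = {1: [101, 102], 2: [103]}  # made-up topic ids
    return pages.get(page, [])

page = 1
while topic_ids := fake_topic_ids(page):
    print(f"page {page}: {topic_ids}")
    page += 1
# page 1: [101, 102]
# page 2: [103]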
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self.permissions = DiscoursePerms(
api_key=credentials["discourse_api_key"],
api_username=credentials["discourse_api_username"],
)
return None
def poll_source(
@@ -184,16 +218,13 @@ class DiscourseConnector(PollConnector):
) -> GenerateDocumentsOutput:
if self.permissions is None:
raise ConnectorMissingCredentialError("Discourse")
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
self._get_categories_map()
yield from self._yield_discourse_documents(start_datetime, end_datetime)
if __name__ == "__main__":
@@ -209,7 +240,5 @@ if __name__ == "__main__":
current = time.time()
one_year_ago = current - 24 * 60 * 60 * 360
latest_docs = connector.poll_source(one_year_ago, current)
print(next(latest_docs))

View File

@@ -1,8 +1,11 @@
from typing import Any
from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
from danswer.connectors.bookstack.connector import BookstackConnector
from danswer.connectors.clickup.connector import ClickupConnector
from danswer.connectors.confluence.connector import ConfluenceConnector
@@ -40,6 +43,8 @@ from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
from danswer.db.models import Credential
class ConnectorMissingException(Exception):
@@ -86,6 +91,10 @@ def identify_connector_class(
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -111,7 +120,6 @@ def identify_connector_class(
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
return connector
@@ -119,10 +127,14 @@ def instantiate_connector(
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector
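instantiate_connector now persists any credentials the connector refreshes during load_credentials instead of returning them to the caller. A toy registry sketch showing the same many-sources-to-one-class mapping used by connector_map (all names here are invented):

from typing import Any

class DummyConnector:
    def __init__(self, **config: Any) -> None:
        self.config = config

# several sources can share one class, as S3/R2/GCS/OCI share BlobStorageConnector
registry = {source: DummyConnector
            for source in ("s3", "r2", "google_cloud_storage", "oci_storage")}
connector = registry["r2"](bucket_name="example")
print(type(connector).__name__, connector.config)  # DummyConnector {'bucket_name': 'example'}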

View File

@@ -85,8 +85,18 @@ def _process_file(
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name_override = all_metadata.get("file_display_name")
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
)
title = (
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
)
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
if isinstance(time_updated, str):
@@ -101,6 +111,7 @@ def _process_file(
for k, v in all_metadata.items()
if k
not in [
"document_id",
"time_updated",
"doc_updated_at",
"link",
@@ -108,6 +119,7 @@ def _process_file(
"secondary_owners",
"filename",
"file_display_name",
"title",
]
}
@@ -126,13 +138,13 @@ def _process_file(
return [
Document(
id=f"FILE_CONNECTOR__{file_name}", # add a prefix to avoid conflicts with other connectors
id=doc_id,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
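The precedence implemented above is: a metadata document_id beats the FILE_CONNECTOR__ fallback, file_display_name beats the basename, and an explicit title beats the display name. The same logic on toy inputs, as a sketch:

import os

def derive_identity(file_name: str, all_metadata: dict) -> tuple[str, str, str]:
    doc_id = all_metadata.get("document_id") or f"FILE_CONNECTOR__{file_name}"
    display = all_metadata.get("file_display_name") or os.path.basename(file_name)
    title = all_metadata["title"] or "" if "title" in all_metadata else display
    return doc_id, display, title

print(derive_identity("dir/report.pdf", {}))
# ('FILE_CONNECTOR__dir/report.pdf', 'report.pdf', 'report.pdf')
print(derive_identity("dir/report.pdf", {"document_id": "abc", "title": "Q3 Report"}))
# ('abc', 'report.pdf', 'Q3 Report')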

View File

@@ -473,6 +473,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
if self.only_org_public:
if "permissions" not in file:
continue

View File

@@ -50,6 +50,12 @@ class PollConnector(BaseConnector):
raise NotImplementedError
class IdConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_source_ids(self) -> set[str]:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.utils.text_processing import make_url_compatible
@@ -13,6 +14,7 @@ class InputType(str, Enum):
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
POLL = "poll" # e.g. calling an API to get all documents in the last hour
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
PRUNE = "prune"
class ConnectorMissingCredentialError(PermissionError):
@@ -116,7 +118,12 @@ class DocumentBase(BaseModel):
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
return self.semantic_identifier if self.title is None else self.title
replace_chars = set(RETURN_SEPARATOR)
title = self.semantic_identifier if self.title is None else self.title
for char in replace_chars:
title = title.replace(char, " ")
title = title.strip()
return title
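The new title cleanup replaces each character of RETURN_SEPARATOR with a space and strips the result before embedding. A toy version; the real separator value lives in danswer.configs.constants, and "\n\r\n" is assumed here purely for illustration:

RETURN_SEPARATOR = "\n\r\n"  # assumed value for illustration

def clean_title(title: str) -> str:
    for char in set(RETURN_SEPARATOR):
        title = title.replace(char, " ")
    return title.strip()

assert clean_title("My\nDoc\r") == "My Doc"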
def get_metadata_str_attributes(self) -> list[str] | None:
if not self.metadata:

View File

@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
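The Notion fix replaces an OR condition that admitted nearly every page with the intended closed window. With start=100 and end=200, a timestamp of 50 passes the old check (since 50 <= 200) but fails the new one:

start, end = 100, 200
for t in (50, 150, 250):
    old = t <= end or t > start   # buggy: true for all three
    new = start < t <= end        # fixed: true only inside the window
    print(t, old, new)
# 50 True False
# 150 True True
# 250 True False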

View File

@@ -11,6 +11,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -23,11 +24,12 @@ from danswer.utils.logger import setup_logger
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
class SalesforceConnector(LoadConnector, PollConnector):
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -77,8 +79,9 @@ class SalesforceConnector(LoadConnector, PollConnector):
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
extracted_id = f"SALESFORCE_{object_dict['Id']}"
extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
salesforce_id = object_dict["Id"]
danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
@@ -89,7 +92,7 @@ class SalesforceConnector(LoadConnector, PollConnector):
]
doc = Document(
id=danswer_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
@@ -229,8 +232,6 @@ class SalesforceConnector(LoadConnector, PollConnector):
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
def poll_source(
@@ -242,6 +243,20 @@ class SalesforceConnector(LoadConnector, PollConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_source_ids(self) -> set[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
all_retrieved_ids: set[str] = set()
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
all_retrieved_ids.update(
f"{ID_PREFIX}{instance_dict.get('Id', '')}"
for instance_dict in query_result["records"]
)
return all_retrieved_ids
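retrieve_all_source_ids issues one SELECT Id query per parent object type and prefixes every id, so a pruning flow can diff the result against what is already indexed. The same harvesting on a stubbed query result:

ID_PREFIX = "SALESFORCE_"
fake_result = {"records": [{"Id": "001xx0001"}, {"Id": "001xx0002"}]}  # stub data

all_ids = {f"{ID_PREFIX}{rec.get('Id', '')}" for rec in fake_result["records"]}
print(sorted(all_ids))  # ['SALESFORCE_001xx0001', 'SALESFORCE_001xx0002']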
if __name__ == "__main__":
connector = SalesforceConnector(

View File

@@ -29,6 +29,7 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.html_utils import web_html_cleanup
from danswer.utils.logger import setup_logger
from danswer.utils.sitemap import list_pages_for_site
logger = setup_logger()
@@ -145,16 +146,21 @@ def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return urls
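The loc extraction itself is plain BeautifulSoup tag matching; the new code only falls back to sitemap discovery when neither loc tags nor a urlset element is present. A minimal reproduction on an inline sitemap (requires beautifulsoup4):

from bs4 import BeautifulSoup

xml = (
    "<urlset>"
    "<url><loc>https://example.com/a</loc></url>"
    "<url><loc>https://example.com/b</loc></url>"
    "</urlset>"
)
soup = BeautifulSoup(xml, "html.parser")
print([loc.text for loc in soup.find_all("loc")])
# ['https://example.com/a', 'https://example.com/b']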
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
@@ -264,7 +270,7 @@ class WebConnector(LoadConnector):
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split(".")[-1],
semantic_identifier=current_url.split("/")[-1],
metadata={},
)
)

View File

@@ -4,6 +4,7 @@ from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
@@ -81,7 +82,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
)
doc_batch = []
for article in articles:
if (
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.label_names
)
):
continue
doc_batch.append(_article_to_document(article))
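An article is now skipped if any of its labels appears in the configured skip list. The same any()-based filter on toy data, with the skip list assumed:

SKIP_LABELS = {"internal", "noindex"}  # stand-in for ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
articles = [
    {"title": "Public FAQ", "labels": ["faq"]},
    {"title": "Internal runbook", "labels": ["internal", "ops"]},
]
kept = [a for a in articles
        if not any(label in SKIP_LABELS for label in a["labels"])]
print([a["title"] for a in kept])  # ['Public FAQ']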

View File

@@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.icons import source_to_github_img_link
@@ -353,6 +354,22 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_standard_answer_blocks(
answer_message: str,
) -> list[Block]:
generate_button_block = ButtonElement(
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
text="Generate Full Answer",
)
answer_block = SectionBlock(text=answer_message)
return [
answer_block,
ActionsBlock(
elements=[generate_button_block],
),
]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
@@ -457,7 +474,7 @@ def build_follow_up_resolved_blocks(
if tag_str:
tag_str += " "
group_str = " ".join([f"<!subteam^{group}>" for group in group_ids])
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
if group_str:
group_str += " "

View File

@@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"
class FeedbackVisibility(str, Enum):

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from typing import cast
@@ -8,6 +9,7 @@ from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import make_slack_api_rate_limited
@@ -21,12 +23,17 @@ from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
)
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
@@ -36,7 +43,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger = setup_logger()
def handle_doc_feedback_button(
@@ -44,7 +51,7 @@ def handle_doc_feedback_button(
client: SocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger_base.error("Missing actions. Unable to build the source feedback view")
logger.error("Missing actions. Unable to build the source feedback view")
return
# Extracts the feedback_id coming from the 'source feedback' button
@@ -72,6 +79,66 @@ def handle_doc_feedback_button(
)
def handle_generate_answer_button(
req: SocketModeRequest,
client: SocketModeClient,
) -> None:
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
user_id = req.payload["user"]["id"]
if not thread_ts:
raise ValueError("Missing thread_ts in the payload")
thread_messages = read_slack_thread(
channel=channel_id, thread=thread_ts, client=client.web_client
)
# remove all assistant messages till we get to the last user message
# we want the new answer to be generated off of the last "question" in
# the thread
for i in range(len(thread_messages) - 1, -1, -1):
if thread_messages[i].role == MessageType.USER:
break
if thread_messages[i].role == MessageType.ASSISTANT:
thread_messages.pop(i)
# Send an ephemeral message to the user that we're generating the answer
respond_in_thread(
client=client.web_client,
channel=channel_id,
receiver_ids=[user_id],
text="I'm working on generating a full answer for you. This may take a moment...",
thread_ts=thread_ts,
)
with Session(get_sqlalchemy_engine()) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
handle_regular_answer(
message_info=SlackMessageInfo(
thread_messages=thread_messages,
channel_to_respond=channel_id,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
bypass_filters=True,
is_bot_msg=False,
is_bot_dm=False,
),
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
channel=channel_id,
logger=cast(logging.Logger, logger),
feedback_reminder_id=None,
)
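The reverse loop above trims trailing assistant messages so regeneration keys off the latest user question; iterating from the end and popping is safe because indices below the current one are untouched. On a toy thread:

thread = ["user: q1", "assistant: a1", "user: q2", "assistant: a2", "assistant: a2b"]
for i in range(len(thread) - 1, -1, -1):
    if thread[i].startswith("user"):
        break
    if thread[i].startswith("assistant"):
        thread.pop(i)
print(thread)  # ['user: q1', 'assistant: a1', 'user: q2']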
def handle_slack_feedback(
feedback_id: str,
feedback_type: str,
@@ -129,7 +196,7 @@ def handle_slack_feedback(
feedback=feedback,
)
else:
logger_base.error(f"Feedback type '{feedback_type}' not supported")
logger.error(f"Feedback type '{feedback_type}' not supported")
if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [
LIKE_BLOCK_ACTION_ID,
@@ -193,11 +260,11 @@ def handle_followup_button(
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:
group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client)
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
@@ -272,7 +339,7 @@ def handle_followup_resolved_button(
)
if not response.get("ok"):
logger_base.error("Unable to delete message for resolved")
logger.error("Unable to delete message for resolved")
if immediate:
msg_text = f"{clicker_name} has marked this question as resolved!"

View File

@@ -1,91 +1,34 @@
import datetime
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
logger_base = setup_logger()
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender:
@@ -173,17 +116,9 @@ def remove_scheduled_feedback_reminder(
def handle_message(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -200,14 +135,22 @@ def handle_message(
)
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
@@ -215,36 +158,13 @@ def handle_message(
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if channel_config is None
else channel_config.response_type == SlackBotResponseType.CITATIONS
)
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
respond_tag_only = False
respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
respond_member_group_list = None
channel_conf = None
if slack_bot_config and slack_bot_config.channel_config:
channel_conf = slack_bot_config.channel_config
if not bypass_filters and "answer_filters" in channel_conf:
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
and "?" not in messages[-1].message
@@ -261,7 +181,7 @@ def handle_message(
)
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
if respond_tag_only and not bypass_filters:
logger.info(
@@ -270,12 +190,23 @@ def handle_message(
)
return False
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
respond_member_group_list, client
)
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client)
send_to = list(set(send_to + user_ids)) if send_to else user_ids
if missing_users:
logger.warning(f"Failed to find these users/groups: {missing_users}")
# If configured to respond to team members only, then cannot be used with a /DanswerBot command
# which would just respond to the sender
if send_to and is_bot_msg:
if sender_id:
respond_in_thread(
client=client,
@@ -290,324 +221,28 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm = get_llm_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
# it allows the slack flow to extract out filters from the user query
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
with Session(get_sqlalchemy_engine()) as db_session:
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
db_session=db_session,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
auto_detect_filters = (
persona.llm_filter_extraction if persona is not None else True
)
if disable_auto_detect_filters:
auto_detect_filters = False
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
# Optionally, respond in thread with the error message, Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
return True
# Edge case handling, for tracking down the Slack usage issue
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=send_to,
text="Hello! Danswer has some results for you!",
blocks=[
SectionBlock(
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if respond_team_member_list:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
if used_standard_answer:
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# Got an answer at this point, can remove reaction and give results
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=send_to,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if respond_team_member_list:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# if no standard answer applies, try a regular answer
issue_with_regular_answer = handle_regular_answer(
message_info=message_info,
slack_bot_config=slack_bot_config,
receiver_ids=send_to,
client=client,
channel=channel,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
)
return issue_with_regular_answer

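The citation handling just above pairs each citation with its retrieved document by id before building the sources block. A minimal, self-contained sketch of that matching step (the `Citation`/`Doc` types here are illustrative stand-ins, not Danswer's models):

from dataclasses import dataclass


@dataclass
class Citation:  # illustrative stand-in, not Danswer's citation model
    citation_num: int
    document_id: str


@dataclass
class Doc:  # illustrative stand-in for a retrieved document
    document_id: str
    title: str


def match_citations(
    citations: list[Citation], top_docs: list[Doc]
) -> list[tuple[int, Doc]]:
    cited_docs = []
    for citation in citations:
        # same pattern as above: first retrieved doc with a matching id, else None
        matching_doc = next(
            (d for d in top_docs if d.document_id == citation.document_id), None
        )
        if matching_doc is not None:
            cited_docs.append((citation.citation_num, matching_doc))
    # sort by citation number so sources render as [1], [2], ...
    cited_docs.sort(key=lambda pair: pair[0])
    return cited_docs

Citations that reference a document missing from top_docs are silently dropped, which is why the sources block only ever shows retrieved documents.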
View File

@@ -0,0 +1,465 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
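# Illustrative use of the decorator above (function name is hypothetical):
#
#     @rate_limits(client=web_client, channel="C0123456789", thread_ts=None)
#     def answer_once() -> None:
#         ...  # only runs after a rate-limit slot is acquired
#
# Because srl is module-level, every function wrapped this way shares one
# SlackRateLimiter: callers spin in srl.waiter() until a slot frees up.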
def handle_regular_answer(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
logger: logging.Logger,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
bypass_acl = False
if (
slack_bot_config
and slack_bot_config.persona
and slack_bot_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_bot_config is None
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to:
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
)
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
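# Worked example of the token split above (illustrative numbers): when the
# request carries thread history (more than one message) and the model allows
# 4096 input tokens with thread_context_percent = 0.5, the history budget is
# int(4096 * 0.5) = 2048 and documents get the remaining 2048 tokens. Without
# a persona, the document budget additionally loses 512 prompt-reserve tokens
# plus however many tokens the query itself consumes.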
try:
# By leaving time_cutoff and favor_recent as None and setting enable_auto_detect_filters,
# the Slack flow can extract filters from the user query itself
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
# Commenting this out because this is only available to the slackbot for now
# later we plan to implement this at the persona level where this will get
# commented back in
# auto_detect_filters = (
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = (
slack_bot_config.enable_auto_filters
if slack_bot_config is not None
else False
)
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
# Optionally, respond in thread with the error message. Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
return True
# Edge case handling, for tracking down the Slack usage issue
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=[
SectionBlock(
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
only_respond_with_citations_or_quotes = (
channel_conf
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer.citations or answer.quotes)
if (
only_respond_with_citations_or_quotes
and not has_citations_or_quotes
and not message_info.bypass_filters
):
logger.error(
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no citations or quotes when trying to answer.",
thread_ts=message_ts_to_respond_to,
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
send_team_member_message(
client=client,
channel=channel,
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True

View File

@@ -0,0 +1,216 @@
import logging
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_messages_by_sessions
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Respond to the user message if it matches any configured standard answers.
Returns a list of matching StandardAnswers if found, otherwise None.
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: logging.Logger,
client: WebClient,
db_session: Session,
) -> bool:
"""
Potentially respond to the user message depending on whether the user's message matches
any of the configured standard answers and also whether those answers have already been
provided in the current thread.
Returns True if matching standard answers were found and sent in response,
in which case the message is handled and needs no further answer.
"""
# if no channel config, then no standard answers are configured
if not slack_bot_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_bot_config.standard_answer_categories if slack_bot_config else []
)
configured_standard_answers = set(
[
standard_answer
for standard_answer_category in configured_standard_answer_categories
for standard_answer in standard_answer_category.standard_answers
]
)
query_msg = message_info.thread_messages[-1]
if slack_thread_id is None:
used_standard_answer_ids = set([])
else:
chat_sessions = get_chat_sessions_by_slack_thread_id(
slack_thread_id=slack_thread_id,
user_id=None,
db_session=db_session,
)
chat_messages = get_chat_messages_by_sessions(
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
user_id=None,
db_session=db_session,
skip_permission_check=True,
)
used_standard_answer_ids = set(
[
standard_answer.id
for chat_message in chat_messages
for standard_answer in chat_message.standard_answers
]
)
usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids
)
if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers(
query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session,
)
else:
matching_standard_answers = []
if matching_standard_answers:
chat_session = create_chat_session(
db_session=db_session,
description="",
user_id=None,
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
danswerbot_flow=True,
slack_thread_id=slack_thread_id,
one_shot=True,
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
formatted_answers = []
for standard_answer in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = (
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
)
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
message=answer_message,
token_count=0,
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=True,
)
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
)
answer_blocks = build_standard_answer_blocks(
answer_message=answer_message,
)
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_info.msg_to_respond,
unfurl=False,
)
if receiver_ids and slack_thread_id:
send_team_member_message(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
)
return True
except Exception as e:
logger.exception(f"Unable to send standard answer message: {e}")
return False
else:
return False

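The block_quotified_answer step above is a plain newline rewrite that turns a standard answer into a Slack block quote. A tiny sketch of what it produces, using a toy string:

answer = "First line\nSecond line"
quoted = ">" + answer.replace("\n", "\n> ")
assert quoted == ">First line\n> Second line"
# Slack renders each ">"-prefixed line as part of a single block quote.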
View File

@@ -0,0 +1,19 @@
from slack_sdk import WebClient
from danswer.danswerbot.slack.utils import respond_in_thread
def send_team_member_message(
client: WebClient,
channel: str,
thread_ts: str,
) -> None:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=thread_ts,
)

View File

@@ -4,9 +4,9 @@ from danswer.configs.constants import DocumentSource
def source_to_github_img_link(source: DocumentSource) -> str | None:
# TODO: store these images somewhere better
if source == DocumentSource.WEB.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Web.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Web.png"
if source == DocumentSource.FILE.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
if source == DocumentSource.GOOGLE_SITES.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleSites.png"
if source == DocumentSource.SLACK.value:
@@ -20,13 +20,13 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
if source == DocumentSource.GITLAB.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gitlab.png"
if source == DocumentSource.CONFLUENCE.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Confluence.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Confluence.png"
if source == DocumentSource.JIRA.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Jira.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Jira.png"
if source == DocumentSource.NOTION.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Notion.png"
if source == DocumentSource.ZENDESK.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Zendesk.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Zendesk.png"
if source == DocumentSource.GONG.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gong.png"
if source == DocumentSource.LINEAR.value:
@@ -38,7 +38,7 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
if source == DocumentSource.ZULIP.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Zulip.png"
if source == DocumentSource.GURU.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Guru.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Guru.png"
if source == DocumentSource.HUBSPOT.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/HubSpot.png"
if source == DocumentSource.DOCUMENT360.value:
@@ -51,8 +51,8 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Sharepoint.png"
if source == DocumentSource.REQUESTTRACKER.value:
# just use file icon for now
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
if source == DocumentSource.INGESTION_API.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"

View File

@@ -18,6 +18,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
@@ -27,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_but
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_followup_resolved_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_generate_answer_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.handlers.handle_message import (
@@ -56,7 +60,6 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# In rare cases, some users have been experiencing a massive amount of trivial
# messages coming through to the Slack Bot. Adding this to avoid exploding LLM
# costs while we track down the cause.
@@ -70,6 +73,9 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# this is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
@@ -92,6 +98,15 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
channel_specific_logger.error("Cannot respond to empty message - skipping")
return False
if (
req.payload.setdefault("event", {}).get("user", "")
== _OFFICIAL_SLACKBOT_USER_ID
):
channel_specific_logger.info(
"Ignoring messages from Slack's official Slackbot"
)
return False
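# Note on the check above: payload.setdefault("event", {}) returns the existing
# "event" dict or inserts (and returns) an empty one, so the chained
# .get("user", "") can never raise even on malformed payloads.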
if (
msg in _SLACK_GREETINGS_TO_IGNORE
or remove_danswer_bot_tag(msg, client=client.web_client)
@@ -255,6 +270,7 @@ def build_request_details(
thread_messages=thread_messages,
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=event.get("user") or None,
bypass_filters=tagged,
is_bot_msg=False,
@@ -272,6 +288,7 @@ def build_request_details(
thread_messages=[single_msg],
channel_to_respond=channel,
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
bypass_filters=True,
is_bot_msg=True,
@@ -341,7 +358,7 @@ def process_message(
failed = handle_message(
message_info=details,
channel_config=slack_bot_config,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
@@ -379,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_followup_resolved_button(req, client, immediate=True)
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
return handle_followup_resolved_button(req, client, immediate=False)
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
@@ -450,13 +469,13 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated

View File

@@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel):
thread_messages: list[ThreadMessage]
channel_to_respond: str
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot

View File

@@ -30,7 +30,7 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.one_shot_answer.models import ThreadMessage
@@ -58,7 +58,7 @@ def rephrase_slack_message(msg: str) -> str:
return messages
try:
llm = get_default_llm(use_fast_llm=False, timeout=5)
llm, _ = get_default_llms(timeout=5)
except GenAIDisabledException:
logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
return msg
@@ -77,17 +77,25 @@ def update_emote_react(
remove: bool,
client: WebClient,
) -> None:
if not message_ts:
logger.error(f"Tried to remove a react in {channel} but no message specified")
return
try:
if not message_ts:
logger.error(
f"Tried to remove a react in {channel} but no message specified"
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
else:
logger.error(f"Was not able to react to user message due to: {e}")
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
@@ -136,16 +144,13 @@ def respond_in_thread(
receiver_ids: list[str] | None = None,
metadata: Metadata | None = None,
unfurl: bool = True,
) -> None:
) -> list[str]:
if not text and not blocks:
raise ValueError("One of `text` or `blocks` must be provided")
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
if not receiver_ids:
response = slack_call(
channel=channel,
text=text,
@@ -157,7 +162,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
response = slack_call(
channel=channel,
@@ -171,6 +178,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
return message_ids
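# With the change above, respond_in_thread now returns the Slack "message_ts"
# of every message it posted. Hypothetical caller sketch:
#
#     posted = respond_in_thread(client=client, channel=channel,
#                                text="...", thread_ts=thread_ts)
#
# A broadcast post yields one timestamp; ephemeral sends yield one timestamp
# per receiver in receiver_ids.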
def build_feedback_id(
@@ -292,7 +302,7 @@ def get_channel_name_from_id(
raise e
def fetch_userids_from_emails(
def fetch_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -308,32 +318,72 @@ def fetch_userids_from_emails(
return user_ids, failed_to_find
def fetch_groupids_from_names(
names: list[str], client: WebClient
def fetch_user_ids_from_groups(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
group_ids: set[str] = set()
user_ids: list[str] = []
failed_to_find: list[str] = []
try:
response = client.usergroups_list()
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return user_ids, given_names
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
group_id = name_id_map.get(given_name) or handle_id_map.get(
given_name.lstrip("@")
)
if not group_id:
failed_to_find.append(given_name)
continue
try:
response = client.usergroups_users_list(usergroup=group_id)
if isinstance(response.data, dict):
user_ids.extend(response.data.get("users", []))
else:
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user group ids: {str(e)}")
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user groups: {str(e)}")
failed_to_find = given_names
return user_ids, failed_to_find
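# Hypothetical call of the function above:
#
#     user_ids, missing = fetch_user_ids_from_groups(["@oncall", "eng"], client)
#
# Each entry is resolved against the workspace's user groups by name or by
# handle (with any leading "@" stripped), then expanded to member user ids via
# usergroups_users_list; names that cannot be resolved come back in `missing`.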
def fetch_group_ids_from_names(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
group_data: list[str] = []
failed_to_find: list[str] = []
try:
response = client.usergroups_list()
if response.get("ok") and "usergroups" in response.data:
all_groups_dicts = response.data["usergroups"] # type: ignore
name_id_map = {d["name"]: d["id"] for d in all_groups_dicts}
handle_id_map = {d["handle"]: d["id"] for d in all_groups_dicts}
for group in names:
if group in name_id_map:
group_ids.add(name_id_map[group])
elif group in handle_id_map:
group_ids.add(handle_id_map[group])
else:
failed_to_find.append(group)
else:
# Most likely a Slack App scope issue
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return group_data, given_names
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
group_id = handle_id_map.get(given_name.lstrip("@"))
group_id = group_id or name_id_map.get(given_name)
if group_id:
group_data.append(group_id)
else:
failed_to_find.append(given_name)
except Exception as e:
failed_to_find = given_names
logger.error(f"Error fetching user groups: {str(e)}")
return list(group_ids), failed_to_find
return group_data, failed_to_find
def fetch_user_semantic_id_from_id(

View File

@@ -1,4 +1,5 @@
from collections.abc import AsyncGenerator
from collections.abc import Callable
from typing import Any
from typing import Dict
@@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
Only used in the EE version. For MIT, just return empty list."""
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
)
return get_default_admin_user_emails_fn()
async def get_user_count() -> int:
@@ -32,7 +47,7 @@ async def get_user_count() -> int:
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC

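The new get_default_admin_user_emails relies on the versioned-implementation pattern: look up an override by module path and attribute name, and fall back to a default when the override isn't present. A minimal sketch of that pattern under assumed semantics (module and function names below are hypothetical; the real fetch_versioned_implementation_with_fallback additionally handles Danswer's EE module resolution):

import importlib
from collections.abc import Callable
from typing import Any


def fetch_with_fallback(
    module: str, attr: str, fallback: Callable[..., Any]
) -> Callable[..., Any]:
    """Return module.attr if importable, otherwise the fallback callable."""
    try:
        return getattr(importlib.import_module(module), attr)
    except (ImportError, AttributeError):
        return fallback


# Hypothetical usage: an EE build ships the override, the MIT build gets the default.
get_admin_emails = fetch_with_fallback("myapp_ee.auth", "get_admin_emails", lambda: [])
print(get_admin_emails())  # -> [] when no override module is installed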
View File

@@ -1,17 +1,26 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
from danswer.db.models import ChatSessionSharedStatus
from danswer.db.models import Prompt
@@ -19,6 +28,7 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
@@ -29,6 +39,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -63,17 +74,59 @@ def get_chat_session_by_id(
return chat_session
def get_chat_sessions_by_slack_thread_id(
slack_thread_id: str,
user_id: UUID | None,
db_session: Session,
) -> Sequence[ChatSession]:
stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id)
if user_id is not None:
stmt = stmt.where(
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
)
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
subquery = (
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
.where(
and_(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER, # Select USER messages
)
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
subquery,
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
& (ChatMessage.id == subquery.c.min_id),
)
first_messages = db_session.execute(query).all()
return dict([(row.chat_session_id, row.message) for row in first_messages])
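# The subquery above is the classic "first row per group" pattern: take
# MIN(ChatMessage.id) per chat_session_id over USER messages, then join back
# to fetch each winning row's text. Roughly equivalent SQL:
#
#   SELECT cm.chat_session_id, cm.message
#   FROM chat_message cm
#   JOIN (
#       SELECT chat_session_id, MIN(id) AS min_id
#       FROM chat_message
#       WHERE chat_session_id IN (...) AND message_type = 'USER'
#       GROUP BY chat_session_id
#   ) sub ON cm.chat_session_id = sub.chat_session_id AND cm.id = sub.min_id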
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_one_shot: bool = False,
only_one_shot: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_one_shot:
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -83,15 +136,71 @@ def get_chat_sessions_by_user(
return list(chat_sessions)
def delete_search_doc_message_relationship(
message_id: int, db_session: Session
) -> None:
db_session.query(ChatMessage__SearchDoc).filter(
ChatMessage__SearchDoc.chat_message_id == message_id
).delete(synchronize_session=False)
db_session.commit()
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
db_session.commit()
def delete_orphaned_search_docs(db_session: Session) -> None:
orphaned_docs = (
db_session.query(SearchDoc)
.outerjoin(ChatMessage__SearchDoc)
.filter(ChatMessage__SearchDoc.chat_message_id.is_(None))
.all()
)
for doc in orphaned_docs:
db_session.delete(doc)
db_session.commit()
def delete_messages_and_files_from_chat_session(
chat_session_id: int, db_session: Session
) -> None:
# Select all messages in this chat session, along with their attached files
messages_with_files = db_session.execute(
select(ChatMessage.id, ChatMessage.files).where(
ChatMessage.chat_session_id == chat_session_id,
)
).fetchall()
for id, files in messages_with_files:
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
for file_info in files or {}:
lobj_name = file_info.get("id")
if lobj_name:
logger.info(f"Deleting file with name: {lobj_name}")
delete_lobj_by_name(lobj_name, db_session)
db_session.execute(
delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
)
db_session.commit()
delete_orphaned_search_docs(db_session)
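# Deletion order in the function above matters: per-message children (tool
# calls, search-doc links, stored file objects) are removed first, then the
# messages themselves, and finally search docs left with no message links are
# garbage-collected by delete_orphaned_search_docs.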
def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int | None = None,
persona_id: int,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
danswerbot_flow: bool = False,
slack_thread_id: str | None = None,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
@@ -101,6 +210,7 @@ def create_chat_session(
prompt_override=prompt_override,
one_shot=one_shot,
danswerbot_flow=danswerbot_flow,
slack_thread_id=slack_thread_id,
)
db_session.add(chat_session)
@@ -139,25 +249,30 @@ def delete_chat_session(
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if hard_delete:
stmt_messages = delete(ChatMessage).where(
ChatMessage.chat_session_id == chat_session_id
)
db_session.execute(stmt_messages)
stmt = delete(ChatSession).where(ChatSession.id == chat_session_id)
db_session.execute(stmt)
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
else:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
chat_session.deleted = True
db_session.commit()
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
old_sessions = db_session.execute(
select(ChatSession.user_id, ChatSession.id).where(
ChatSession.time_created < cutoff_time
)
).fetchall()
for user_id, session_id in old_sessions:
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
def get_chat_message(
chat_message_id: int,
user_id: UUID | None,
@@ -183,6 +298,39 @@ def get_chat_message(
return chat_message
def get_chat_messages_by_sessions(
chat_session_ids: list[int],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
) -> Sequence[ChatMessage]:
if not skip_permission_check:
for chat_session_id in chat_session_ids:
get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
stmt = (
select(ChatMessage)
.where(ChatMessage.chat_session_id.in_(chat_session_ids))
.order_by(nullsfirst(ChatMessage.parent_message))
)
return db_session.execute(stmt).scalars().all()
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
stmt = (
select(SearchDoc)
.join(
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
)
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
)
return list(db_session.scalars(stmt).all())
def get_chat_messages_by_session(
chat_session_id: int,
user_id: UUID | None,
@@ -203,8 +351,6 @@ def get_chat_messages_by_session(
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
if prefetch_tool_calls:
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -259,6 +405,7 @@ def create_new_chat_message(
rephrased_query: str | None = None,
error: str | None = None,
reference_docs: list[DBSearchDoc] | None = None,
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
@@ -277,6 +424,7 @@ def create_new_chat_message(
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
@@ -371,16 +519,46 @@ def get_doc_query_identifiers_from_model(
)
raise ValueError("Docs references do not belong to user")
if any(
[doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs]
):
raise ValueError("Invalid reference doc, not from this chat session.")
try:
if any(
[
doc.chat_messages[0].chat_session_id != chat_session.id
for doc in search_docs
]
):
raise ValueError("Invalid reference doc, not from this chat session.")
except IndexError:
# This happens when the doc has no chat_messages associated with it,
# which is an edge case where the chat message failed to save.
# This usually happens when the LLM fails either immediately or partway through.
raise RuntimeError("Chat session failed, please start a new session.")
doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs]
return doc_query_identifiers
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
)
if relevance_data is not None:
db_session.execute(
update(SearchDoc)
.where(SearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
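# Relevance summaries above are keyed by "<document_id>-<chunk_ind>": a doc id
# of "web_42" with chunk index 3 is looked up as "web_42-3". Docs with no entry
# keep their existing is_relevant / relevance_explanation values.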
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
@@ -395,17 +573,19 @@ def create_db_search_doc(
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=server_search_doc.relevance_explanation,
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=server_search_doc.match_highlights,
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
is_internet=server_search_doc.is_internet,
)
db_session.add(db_search_doc)
db_session.commit()
return db_search_doc
@@ -431,14 +611,17 @@ def translate_db_search_doc_to_server_search_doc(
hidden=db_search_doc.hidden,
metadata=db_search_doc.doc_metadata if not remove_doc_content else {},
score=db_search_doc.score,
match_highlights=db_search_doc.match_highlights
if not remove_doc_content
else [],
match_highlights=(
db_search_doc.match_highlights if not remove_doc_content else []
),
relevance_explanation=db_search_doc.relevance_explanation,
is_relevant=db_search_doc.is_relevant,
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
secondary_owners=db_search_doc.secondary_owners
if not remove_doc_content
else [],
secondary_owners=(
db_search_doc.secondary_owners if not remove_doc_content else []
),
is_internet=db_search_doc.is_internet,
)
@@ -456,9 +639,11 @@ def get_retrieval_docs_from_chat_message(
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage, remove_doc_content: bool = False
chat_message: ChatMessage,
remove_doc_content: bool = False,
) -> ChatMessageDetail:
chat_msg_detail = ChatMessageDetail(
chat_session_id=chat_message.chat_session_id,
message_id=chat_message.id,
parent_message=chat_message.parent_message,
latest_child_message=chat_message.latest_child_message,
@@ -479,6 +664,7 @@ def translate_db_message_to_chat_message_detail(
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
)
return chat_msg_detail

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
@@ -84,6 +85,9 @@ def create_connector(
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
prune_freq=connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ,
disabled=connector_data.disabled,
)
db_session.add(connector)
@@ -113,6 +117,11 @@ def update_connector(
connector.input_type = connector_data.input_type
connector.connector_specific_config = connector_data.connector_specific_config
connector.refresh_freq = connector_data.refresh_freq
connector.prune_freq = (
connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ
)
connector.disabled = connector_data.disabled
db_session.commit()
@@ -259,6 +268,7 @@ def create_initial_default_connector(db_session: Session) -> None:
input_type=InputType.LOAD_STATE,
connector_specific_config={},
refresh_freq=None,
prune_freq=None,
)
db_session.add(connector)
db_session.commit()

View File

@@ -151,7 +151,8 @@ def add_credential_to_connector(
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
user: User,
is_public: bool,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
@@ -185,6 +186,7 @@ def add_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,
is_public=is_public,
)
db_session.add(association)
db_session.commit()
@@ -199,7 +201,7 @@ def add_credential_to_connector(
def remove_credential_from_connector(
connector_id: int,
credential_id: int,
user: User,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)

View File

@@ -12,6 +12,7 @@ from danswer.connectors.gmail.constants import (
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
@@ -142,6 +143,18 @@ def delete_credential(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
.all()
)
if associated_connectors:
raise ValueError(
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
"Please delete all associated connectors first."
)
db_session.delete(credential)
db_session.commit()

View File

@@ -327,6 +327,7 @@ def prepare_to_modify_documents(
Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
NOTE: this function will commit any existing transaction.
"""
db_session.commit() # ensure that we're not in a transaction
lock_acquired = False

View File

@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -31,6 +36,7 @@ def create_embedding_model(
query_prefix=model_details.query_prefix,
passage_prefix=model_details.passage_prefix,
status=status,
cloud_provider_id=model_details.cloud_provider_id,
# Every single embedding model except the initial one from migrations has this name
# The initial one from migration is called "danswer_chunk"
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
@@ -42,6 +48,42 @@ def create_embedding_model(
return embedding_model
def get_model_id_from_name(
db_session: Session, embedding_provider_name: str
) -> int | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.name == embedding_provider_name
)
provider = db_session.execute(query).scalars().first()
return provider.id if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
current_embedding_model = EmbeddingModelDetail.from_model(
get_current_db_embedding_model(db_session=db_session)
)
if (
current_embedding_model is None
or current_embedding_model.cloud_provider_id is None
):
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
query = (
select(EmbeddingModel)

View File

@@ -2,11 +2,34 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
existing_provider = (
db_session.query(CloudEmbeddingProviderModel)
.filter_by(name=provider.name)
.first()
)
if existing_provider:
for key, value in provider.dict().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
db_session.refresh(existing_provider)
return CloudEmbeddingProvider.from_request(existing_provider)
def upsert_llm_provider(
db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
@@ -26,7 +49,6 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
@@ -46,10 +68,26 @@ def upsert_llm_provider(
return FullLLMProvider.from_model(llm_provider_model)
def fetch_existing_embedding_providers(
db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_embedding_provider(
db_session: Session, provider_id: int
) -> CloudEmbeddingProviderModel | None:
return db_session.scalar(
select(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.id == provider_id
)
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
return FullLLMProvider.from_model(provider_model)
def remove_embedding_provider(
db_session: Session, embedding_provider_name: str
) -> None:
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.name == embedding_provider_name
)
)
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)

View File

@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -246,6 +247,39 @@ class Persona__Tool(Base):
tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True)
class StandardAnswer__StandardAnswerCategory(Base):
__tablename__ = "standard_answer__standard_answer_category"
standard_answer_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
)
class SlackBotConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_bot_config__standard_answer_category"
slack_bot_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot_config.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
)
class ChatMessage__StandardAnswer(Base):
__tablename__ = "chat_message__standard_answer"
chat_message_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id"), primary_key=True
)
standard_answer_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer.id"), primary_key=True
)
"""
Documents/Indexing Tables
"""
@@ -383,6 +417,7 @@ class Connector(Base):
postgresql.JSONB()
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
@@ -435,7 +470,7 @@ class Credential(Base):
class EmbeddingModel(Base):
__tablename__ = "embedding_model"
# The ID is also used to indicate the order in which the admin configured the models
id: Mapped[int] = mapped_column(primary_key=True)
model_name: Mapped[str] = mapped_column(String)
model_dim: Mapped[int] = mapped_column(Integer)
@@ -447,6 +482,16 @@ class EmbeddingModel(Base):
)
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
cloud_provider_id: Mapped[int | None] = mapped_column(
ForeignKey("embedding_provider.id")
)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
foreign_keys=[cloud_provider_id],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="embedding_model"
)
@@ -466,6 +511,18 @@ class EmbeddingModel(Base):
),
)
def __repr__(self) -> str:
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
class IndexAttempt(Base):
"""
@@ -485,6 +542,7 @@ class IndexAttempt(Base):
ForeignKey("credential.id"),
nullable=True,
)
# Some index attempts that run from the beginning will still have this as False
# This is only for attempts that are explicitly marked as from the start via
# the run once API
@@ -611,6 +669,10 @@ class SearchDoc(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)
chat_messages = relationship(
"ChatMessage",
@@ -660,6 +722,12 @@ class ChatSession(Base):
ForeignKey("chat_folder.id"), nullable=True
)
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
slack_thread_id: Mapped[str | None] = mapped_column(
String, nullable=True, default=None
)
# the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the
# `send-message` call will take precedence over these.
@@ -687,7 +755,7 @@ class ChatSession(Base):
"ChatFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
"ChatMessage", back_populates="chat_session"
)
persona: Mapped["Persona"] = relationship("Persona")
@@ -705,6 +773,11 @@ class ChatMessage(Base):
id: Mapped[int] = mapped_column(primary_key=True)
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
alternate_assistant_id = mapped_column(
Integer, ForeignKey("persona.id"), nullable=True
)
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -733,11 +806,15 @@ class ChatMessage(Base):
chat_session: Mapped[ChatSession] = relationship("ChatSession")
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship(
"ChatMessageFeedback", back_populates="chat_message"
"ChatMessageFeedback",
back_populates="chat_message",
)
document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="chat_message"
"DocumentRetrievalFeedback",
back_populates="chat_message",
)
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
@@ -748,6 +825,11 @@ class ChatMessage(Base):
"ToolCall",
back_populates="message",
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="chat_messages",
)
class ChatFolder(Base):
@@ -784,7 +866,9 @@ class DocumentRetrievalFeedback(Base):
__tablename__ = "document_retrieval_feedback"
id: Mapped[int] = mapped_column(primary_key=True)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
# How high up this document is in the results, 1 for first
document_rank: Mapped[int] = mapped_column(Integer)
@@ -794,7 +878,9 @@ class DocumentRetrievalFeedback(Base):
)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="document_feedbacks"
"ChatMessage",
back_populates="document_feedbacks",
foreign_keys=[chat_message_id],
)
document: Mapped[Document] = relationship(
"Document", back_populates="retrieval_feedbacks"
@@ -805,22 +891,21 @@ class ChatMessageFeedback(Base):
__tablename__ = "chat_feedback"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="chat_message_feedbacks"
"ChatMessage",
back_populates="chat_message_feedbacks",
foreign_keys=[chat_message_id],
)
"""
Structures, Organizational, Configurations Tables
"""
class LLMProvider(Base):
__tablename__ = "llm_provider"
@@ -849,6 +934,29 @@ class LLMProvider(Base):
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
default_model_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("embedding_model.id"), nullable=True
)
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.cloud_provider_id",
)
default_model: Mapped["EmbeddingModel"] = relationship(
"EmbeddingModel", foreign_keys=[default_model_id]
)
def __repr__(self) -> str:
return f"<EmbeddingProvider(name='{self.name}')>"
class DocumentSet(Base):
__tablename__ = "document_set"
@@ -928,6 +1036,7 @@ class Tool(Base):
# ID of the tool in the codebase, only applies for in-code tools.
# tools defined via the UI will have this as None
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
display_name: Mapped[str] = mapped_column(String, nullable=True)
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
@@ -1057,13 +1166,60 @@ class ChannelConfig(TypedDict):
channel_names: list[str]
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
respond_team_member_list: NotRequired[list[str]]
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
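# Aside (not part of the diff): the partial unique index above allows many
# inactive rows per keyword but at most one active one, which is what lets
# remove_standard_answer soft-delete by flipping active to False instead of
# deleting the row. Equivalent Postgres DDL sketch:
#
#   CREATE UNIQUE INDEX unique_keyword_active
#       ON standard_answer (keyword, active)
#       WHERE active;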
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
@@ -1084,7 +1240,16 @@ class SlackBotConfig(Base):
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
persona: Mapped[Persona | None] = relationship("Persona")
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
)
class TaskQueueState(Base):
@@ -1364,3 +1529,30 @@ class EmailToExternalUserCache(Base):
)
user = relationship("User")
class UsageReport(Base):
"""This stores metadata about usage reports generated by admin including user who generated
them as well las the period they cover. The actual zip file of the report is stored as a lo
using the PGFileStore
"""
__tablename__ = "usage_reports"
id: Mapped[int] = mapped_column(primary_key=True)
report_name: Mapped[str] = mapped_column(ForeignKey("file_store.file_name"))
# if None, report was auto-generated
requestor_user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
period_from: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True)
)
period_to: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True))
requestor = relationship("User")
file = relationship("PGFileStore")

View File

@@ -12,8 +12,8 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import Persona
@@ -62,19 +62,6 @@ def create_update_persona(
) -> PersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
document_sets = list(
get_document_sets_by_ids(
document_set_ids=create_persona_request.document_set_ids,
db_session=db_session,
)
)
prompts = list(
get_prompts_by_ids(
prompt_ids=create_persona_request.prompt_ids,
db_session=db_session,
)
)
try:
persona = upsert_persona(
persona_id=persona_id,
@@ -85,9 +72,9 @@ def create_update_persona(
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
recency_bias=create_persona_request.recency_bias,
prompts=prompts,
prompt_ids=create_persona_request.prompt_ids,
tool_ids=create_persona_request.tool_ids,
document_sets=document_sets,
document_set_ids=create_persona_request.document_set_ids,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
@@ -330,13 +317,13 @@ def upsert_persona(
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DocumentSet] | None,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
is_public: bool,
db_session: Session,
prompt_ids: list[int] | None = None,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
@@ -356,6 +343,28 @@ def upsert_persona(
if not tools and tool_ids:
raise ValueError("Tools not found")
# Fetch and attach document_sets by IDs
document_sets = None
if document_set_ids is not None:
document_sets = (
db_session.query(DocumentSet)
.filter(DocumentSet.id.in_(document_set_ids))
.all()
)
if not document_sets and document_set_ids:
raise ValueError("document_sets not found")
# Fetch and attach prompts by IDs
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
@@ -383,10 +392,10 @@ def upsert_persona(
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts
persona.prompts = prompts or []
if tools is not None:
persona.tools = tools
persona.tools = tools or []
else:
persona = Persona(
@@ -453,6 +462,14 @@ def update_persona_visibility(
db_session.commit()
def validate_persona_tools(tools: list[Tool]) -> None:
for tool in tools:
if tool.name == "InternetSearchTool" and not BING_API_KEY:
raise ValueError(
"Bing API key not found, please contact your Danswer admin to get it added!"
)
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
# if user is None, assume that no-auth is turned on
if user is None:
@@ -537,12 +554,22 @@ def get_persona_by_id(
user: User | None,
db_session: Session,
include_deleted: bool = False,
is_for_edit: bool = True, # NOTE: assume true for safety
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
or_conditions = []
# if user is an admin, they should have access to all Personas
if user is not None and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None)))
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
# if we aren't editing, also give access to all public personas
if not is_for_edit:
or_conditions.append(Persona.is_public.is_(True))
if or_conditions:
stmt = stmt.where(or_(*or_conditions))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))

View File

@@ -1,3 +1,4 @@
import tempfile
from io import BytesIO
from typing import IO
@@ -6,6 +7,8 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.models import PGFileStore
from danswer.file_store.constants import MAX_IN_MEMORY_SIZE
from danswer.file_store.constants import STANDARD_CHUNK_SIZE
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -15,6 +18,25 @@ def get_pg_conn_from_session(db_session: Session) -> connection:
return db_session.connection().connection.connection # type: ignore
def get_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> PGFileStore:
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
if not pgfilestore:
raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
return pgfilestore
def delete_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> None:
db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
def create_populate_lobj(
content: IO,
db_session: Session,
@@ -26,18 +48,40 @@ def create_populate_lobj(
pg_conn = get_pg_conn_from_session(db_session)
large_object = pg_conn.lobject()
large_object.write(content.read())
# write in multiple chunks to avoid loading the whole file into memory
while True:
chunk = content.read(STANDARD_CHUNK_SIZE)
if not chunk:
break
large_object.write(chunk)
large_object.close()
return large_object.oid
def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO:
def read_lobj(
lobj_oid: int,
db_session: Session,
mode: str | None = None,
use_tempfile: bool = False,
) -> IO:
pg_conn = get_pg_conn_from_session(db_session)
large_object = (
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
)
return BytesIO(large_object.read())
if use_tempfile:
temp_file = tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE)
while True:
chunk = large_object.read(STANDARD_CHUNK_SIZE)
if not chunk:
break
temp_file.write(chunk)
temp_file.seek(0)
return temp_file
else:
return BytesIO(large_object.read())
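# Sketch (not part of the diff) of the streaming pattern used above: copy in
# fixed-size chunks and spool to disk past a memory cap, so large blobs never
# sit fully in memory. Constants mirror danswer.file_store.constants.
import io
import tempfile

MAX_IN_MEMORY_SIZE = 30 * 1024 * 1024  # 30MB
STANDARD_CHUNK_SIZE = 10 * 1024 * 1024  # 10MB chunks

def copy_in_chunks(src, dst, chunk_size: int = STANDARD_CHUNK_SIZE) -> None:
    # fixed-size reads keep peak memory bounded regardless of blob size
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            break
        dst.write(chunk)

src = io.BytesIO(b"x" * (2 * STANDARD_CHUNK_SIZE))  # stand-in for a pg large object
with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as temp_file:
    copy_in_chunks(src, temp_file)
    temp_file.seek(0)  # rewind before handing the stream to the caller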
def delete_lobj_by_id(
@@ -48,6 +92,23 @@ def delete_lobj_by_id(
pg_conn.lobject(lobj_oid).unlink()
def delete_lobj_by_name(
lobj_name: str,
db_session: Session,
) -> None:
try:
pgfilestore = get_pgfilestore_by_file_name(lobj_name, db_session)
except RuntimeError:
logger.info(f"no file with name {lobj_name} found")
return
pg_conn = get_pg_conn_from_session(db_session)
pg_conn.lobject(pgfilestore.lobj_oid).unlink()
delete_pgfilestore_by_file_name(lobj_name, db_session)
db_session.commit()
def upsert_pgfilestore(
file_name: str,
display_name: str | None,
@@ -87,22 +148,3 @@ def upsert_pgfilestore(
db_session.commit()
return pgfilestore
def get_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> PGFileStore:
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
if not pgfilestore:
raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
return pgfilestore
def delete_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> None:
db_session.query(PGFileStore).filter_by(file_name=file_name).delete()

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
@@ -15,6 +14,7 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
from danswer.search.enums import RecencyBiasSetting
@@ -42,12 +42,6 @@ def create_slack_bot_persona(
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(
get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
)
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
@@ -59,10 +53,10 @@ def create_slack_bot_persona(
description="",
num_chunks=num_chunks,
llm_relevance_filter=True,
llm_filter_extraction=True,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.AUTO,
prompts=[default_prompt],
document_sets=document_sets,
prompt_ids=[default_prompt.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -79,12 +73,25 @@ def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
db_session.add(slack_bot_config)
db_session.commit()
@@ -97,6 +104,8 @@ def update_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
slack_bot_config = db_session.scalar(
@@ -106,6 +115,16 @@ def update_slack_bot_config(
raise ValueError(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
# get the existing persona id before updating the object
existing_persona_id = slack_bot_config.persona_id
@@ -115,6 +134,10 @@ def update_slack_bot_config(
slack_bot_config.persona_id = persona_id
slack_bot_config.channel_config = channel_config
slack_bot_config.response_type = response_type
slack_bot_config.standard_answer_categories = list(
existing_standard_answer_categories
)
slack_bot_config.enable_auto_filters = enable_auto_filters
# if the persona has changed, then clean up the old persona
if persona_id != existing_persona_id and existing_persona_id:

View File

@@ -0,0 +1,239 @@
import string
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import StandardAnswer
from danswer.db.models import StandardAnswerCategory
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_category_validity(category_name: str) -> bool:
"""If a category name is too long, it should not be used (it will cause an error in Postgres
as the unique constraint can only apply to entries that are less than 2704 bytes).
Additionally, extremely long categories are not really usable / useful."""
if len(category_name) > 255:
logger.error(
f"Category with name '{category_name}' is too long, cannot be used"
)
return False
return True
def insert_standard_answer_category(
category_name: str, db_session: Session
) -> StandardAnswerCategory:
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category = StandardAnswerCategory(name=category_name)
db_session.add(standard_answer_category)
db_session.commit()
return standard_answer_category
def insert_standard_answer(
keyword: str,
answer: str,
category_ids: list[int],
db_session: Session,
) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer = StandardAnswer(
keyword=keyword,
answer=answer,
categories=existing_categories,
active=True,
)
db_session.add(standard_answer)
db_session.commit()
return standard_answer
def update_standard_answer(
standard_answer_id: int,
keyword: str,
answer: str,
category_ids: list[int],
db_session: Session,
) -> StandardAnswer:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer.keyword = keyword
standard_answer.answer = answer
standard_answer.categories = list(existing_categories)
db_session.commit()
return standard_answer
def remove_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> None:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
standard_answer.active = False
db_session.commit()
def update_standard_answer_category(
standard_answer_category_id: int,
category_name: str,
db_session: Session,
) -> StandardAnswerCategory:
standard_answer_category = db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
if standard_answer_category is None:
raise ValueError(
f"No standard answer category with id {standard_answer_category_id}"
)
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category.name = category_name
db_session.commit()
return standard_answer_category
def fetch_standard_answer_category(
standard_answer_category_id: int,
db_session: Session,
) -> StandardAnswerCategory | None:
return db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
def fetch_standard_answer_categories_by_names(
standard_answer_category_names: list[str],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.name.in_(standard_answer_category_names)
)
).all()
def fetch_standard_answer_categories_by_ids(
standard_answer_category_ids: list[int],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id.in_(standard_answer_category_ids)
)
).all()
def fetch_standard_answer_categories(
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(select(StandardAnswerCategory)).all()
def fetch_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> StandardAnswer | None:
return db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
def find_matching_standard_answers(
id_in: list[int],
query: str,
db_session: Session,
) -> list[StandardAnswer]:
stmt = (
select(StandardAnswer)
.where(StandardAnswer.active.is_(True))
.where(StandardAnswer.id.in_(id_in))
)
possible_standard_answers = db_session.scalars(stmt).all()
matching_standard_answers: list[StandardAnswer] = []
for standard_answer in possible_standard_answers:
# Remove punctuation and split the keyword into individual words
keyword_words = "".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
# Remove punctuation and split the query into individual words
query_words = "".join(
char for char in query.lower() if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(standard_answer)
return matching_standard_answers
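# Worked example (hypothetical data, not from the diff) of the matching rule
# above: a standard answer fires only when every punctuation-stripped word of
# its keyword appears among the query's words.
import string

def _words(text: str) -> list[str]:
    return "".join(
        c for c in text.lower() if c not in string.punctuation
    ).split()

query_words = _words("How do I reset my account password?")
assert all(w in query_words for w in _words("reset password"))  # matches
assert not all(w in query_words for w in _words("reset VPN password"))  # no "vpn"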
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
return db_session.scalars(
select(StandardAnswer).where(StandardAnswer.active.is_(True))
).all()
def create_initial_default_standard_answer_category(db_session: Session) -> None:
default_category_id = 0
default_category_name = "General"
default_category = fetch_standard_answer_category(
standard_answer_category_id=default_category_id,
db_session=db_session,
)
if default_category is not None:
if default_category.name != default_category_name:
raise ValueError(
"DB is not in a valid initial state. "
"Default standard answer category does not have expected name."
)
return
standard_answer_category = StandardAnswerCategory(
id=default_category_id,
name=default_category_name,
)
db_session.add(standard_answer_category)
db_session.commit()

View File

@@ -1,5 +1,6 @@
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -91,9 +92,10 @@ def create_or_add_document_tag_list(
new_tags.append(new_tag)
existing_tag_values.add(tag_value)
logger.debug(
f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}"
)
if new_tags:
logger.debug(
f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}"
)
all_tags = existing_tags + new_tags
@@ -106,18 +108,28 @@ def create_or_add_document_tag_list(
def get_tags_by_value_prefix_for_source_types(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
sources: list[DocumentSource] | None,
limit: int | None,
db_session: Session,
) -> list[Tag]:
query = select(Tag)
if tag_value_prefix:
query = query.where(Tag.tag_value.startswith(tag_value_prefix))
if tag_key_prefix or tag_value_prefix:
conditions = []
if tag_key_prefix:
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
if tag_value_prefix:
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
query = query.where(or_(*conditions))
if sources:
query = query.where(Tag.source.in_(sources))
if limit:
query = query.limit(limit)
result = db_session.execute(query)
tags = result.scalars().all()

View File

@@ -26,6 +26,23 @@ def get_latest_task(
return latest_task
def get_latest_task_by_type(
task_name: str,
db_session: Session,
) -> TaskQueueState | None:
stmt = (
select(TaskQueueState)
.where(TaskQueueState.task_name.like(f"%{task_name}%"))
.order_by(desc(TaskQueueState.id))
.limit(1)
)
result = db_session.execute(stmt)
latest_task = result.scalars().first()
return latest_task
def register_task(
task_id: str,
task_name: str,
@@ -66,7 +83,7 @@ def mark_task_finished(
db_session.commit()
def check_live_task_not_timed_out(
def check_task_is_live_and_not_timed_out(
task: TaskQueueState,
db_session: Session,
timeout: int = JOB_TIMEOUT,

View File

@@ -1,15 +1,18 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.schema import Column
from danswer.db.models import User
def list_users(db_session: Session) -> Sequence[User]:
def list_users(db_session: Session, q: str = "") -> Sequence[User]:
"""List all users. No pagination as of now, as the # of users
is assumed to be relatively small (<< 1 million)"""
return db_session.scalars(select(User)).unique().all()
query = db_session.query(User)
if q:
query = query.filter(Column("email").ilike("%{}%".format(q)))
return query.all()
def get_user_by_email(email: str, db_session: Session) -> User | None:

View File

@@ -6,7 +6,7 @@ from typing import Any
from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@dataclass(frozen=True)
@@ -186,7 +186,7 @@ class IdRetrievalCapable(abc.ABC):
min_chunk_ind: int | None,
max_chunk_ind: int | None,
user_access_control_list: list[str] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Fetch chunk(s) based on document id
@@ -222,7 +222,7 @@ class KeywordCapable(abc.ABC):
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
information required for query time purposes. For example, some details of the document
@@ -262,7 +262,7 @@ class VectorCapable(abc.ABC):
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run vector/semantic search and return a list of inference chunks.
@@ -298,7 +298,7 @@ class HybridCapable(abc.ABC):
num_to_retrieve: int,
offset: int = 0,
hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run hybrid search and return a list of inference chunks.
@@ -348,7 +348,7 @@ class AdminCapable(abc.ABC):
filters: IndexFilters,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run the special search for the admin document explorer page

View File

@@ -91,6 +91,9 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}
field doc_updated_at type int {
indexing: summary | attribute
}
@@ -150,43 +153,41 @@ schema DANSWER_CHUNK_NAME {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
# This must be a separate function for normalize_linear to work
function vector_score() {
function title_vector_score() {
expression {
# If no title, the full vector score comes from the content embedding
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
((1 - query(title_content_ratio)) * closeness(field, embeddings))
}
}
# This must be a separate function for normalize_linear to work
function keyword_score() {
expression {
(query(title_content_ratio) * bm25(title)) +
((1 - query(title_content_ratio)) * bm25(content))
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
}
}
first-phase {
expression: vector_score
expression: closeness(field, embeddings)
}
# Weighted average between Vector Search and BM-25
# Each is a weighted average between the Title and Content fields
# Finally, each doc is boosted by its user-feedback-based boost and recency
# If any embedding or index field is missing, it just receives a score of 0
# Assumptions:
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
# therefore not normalizing before combining.
# - Documents without a title get a title score of 0; this is fine, as documents
# without any title match should be penalized.
global-phase {
expression {
(
# Weighted Vector Similarity Score
(query(alpha) * normalize_linear(vector_score)) +
(
query(alpha) * (
(query(title_content_ratio) * normalize_linear(title_vector_score))
+
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
)
)
+
# Weighted Keyword Similarity Score
((1 - query(alpha)) * normalize_linear(keyword_score))
(
(1 - query(alpha)) * (
(query(title_content_ratio) * normalize_linear(bm25(title)))
+
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
)
)
)
# Boost based on user feedback
* document_boost
@@ -201,8 +202,6 @@ schema DANSWER_CHUNK_NAME {
bm25(content)
closeness(field, title_embedding)
closeness(field, embeddings)
keyword_score
vector_score
document_boost
recency_bias
closest(embeddings)
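# Plain-Python restatement (sketch, not part of the schema) of the reworked
# global-phase score above. Inputs stand in for the normalize_linear'd match
# features; alpha weights vector vs. keyword search and title_content_ratio
# weights title vs. content within each.
def hybrid_score(
    title_vec: float,
    content_vec: float,
    bm25_title: float,
    bm25_content: float,
    alpha: float,
    title_content_ratio: float,
    document_boost: float = 1.0,
) -> float:
    vector_part = (
        title_content_ratio * title_vec
        + (1 - title_content_ratio) * content_vec
    )
    keyword_part = (
        title_content_ratio * bm25_title
        + (1 - title_content_ratio) * bm25_content
    )
    # the schema further multiplies in recency boosting, omitted here
    return (alpha * vector_part + (1 - alpha) * keyword_part) * document_boost

# e.g. a strong content match outweighs a weak title match at ratio 0.25
print(hybrid_score(0.2, 0.9, 0.1, 0.8, alpha=0.5, title_content_ratio=0.25))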

View File

@@ -41,6 +41,7 @@ from danswer.configs.constants import HIDDEN
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import METADATA
from danswer.configs.constants import METADATA_LIST
from danswer.configs.constants import METADATA_SUFFIX
from danswer.configs.constants import PRIMARY_OWNERS
from danswer.configs.constants import RECENCY_BIAS
from danswer.configs.constants import SECONDARY_OWNERS
@@ -51,7 +52,6 @@ from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.constants import TITLE
from danswer.configs.constants import TITLE_EMBEDDING
from danswer.configs.constants import TITLE_SEPARATOR
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -64,7 +64,7 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
from danswer.document_index.vespa.utils import replace_invalid_doc_id_characters
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
@@ -119,6 +119,7 @@ def _does_document_exist(
chunk. This checks for whether the chunk exists already in the index"""
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
doc_fetch_response = http_client.get(doc_url)
if doc_fetch_response.status_code == 404:
return False
@@ -346,8 +347,10 @@ def _index_vespa_chunk(
TITLE: remove_invalid_unicode_chars(title) if title else None,
SKIP_TITLE_EMBEDDING: not title,
CONTENT: remove_invalid_unicode_chars(chunk.content),
# This duplication of `content` is needed for keyword highlighting :(
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
# which contains the title prefix and metadata suffix
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -355,6 +358,7 @@ def _index_vespa_chunk(
METADATA: json.dumps(document.metadata),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,
@@ -559,7 +563,9 @@ def _process_dynamic_summary(
return processed_summary
def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
def _vespa_hit_to_inference_chunk(
hit: dict[str, Any], null_score: bool = False
) -> InferenceChunkUncleaned:
fields = cast(dict[str, Any], hit["fields"])
# parse fields that are stored as strings, but are really json / datetime
@@ -582,19 +588,6 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
)
# Remove the title from the first chunk, as every chunk already includes
# its semantic identifier for the LLM
content = fields[CONTENT]
if fields[CHUNK_ID] == 0:
parts = content.split(TITLE_SEPARATOR, maxsplit=1)
content = parts[1] if len(parts) > 1 and "\n" not in parts[0] else content
# A user ran into this; unclear why it can happen, so add error handling here
blurb = fields.get(BLURB)
if not blurb:
logger.error(f"Chunk with id {fields.get(SEMANTIC_IDENTIFIER)} has no blurb")
blurb = ""
source_links = fields.get(SOURCE_LINKS, {})
source_links_dict_unprocessed = (
json.loads(source_links) if isinstance(source_links, str) else source_links
@@ -604,29 +597,33 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
}
return InferenceChunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=blurb,
content=content,
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
source_links=source_links_dict,
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
source_type=fields[SOURCE_TYPE],
title=fields.get(TITLE),
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
boost=fields.get(BOOST, 1),
recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
score=hit.get("relevance", 0),
score=None if null_score else hit.get("relevance", 0),
hidden=fields.get(HIDDEN, False),
primary_owners=fields.get(PRIMARY_OWNERS),
secondary_owners=fields.get(SECONDARY_OWNERS),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
match_highlights=match_highlights,
updated_at=updated_at,
)
@retry(tries=3, delay=1, backoff=2)
def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[InferenceChunk]:
def _query_vespa(
query_params: Mapping[str, str | int | float]
) -> list[InferenceChunkUncleaned]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError("No/empty query received")
@@ -681,16 +678,6 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc
return inference_chunks
@retry(tries=3, delay=1, backoff=2)
def _inference_chunk_by_vespa_id(vespa_id: str, index_name: str) -> InferenceChunk:
res = requests.get(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_id}"
)
res.raise_for_status()
return _vespa_hit_to_inference_chunk(res.json())
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -735,6 +722,7 @@ class VespaIndex(DocumentIndex):
f"{SOURCE_TYPE}, "
f"{SOURCE_LINKS}, "
f"{SEMANTIC_IDENTIFIER}, "
f"{TITLE}, "
f"{SECTION_CONTINUATION}, "
f"{BOOST}, "
f"{HIDDEN}, "
@@ -742,6 +730,7 @@ class VespaIndex(DocumentIndex):
f"{PRIMARY_OWNERS}, "
f"{SECONDARY_OWNERS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)
@@ -977,7 +966,7 @@ class VespaIndex(DocumentIndex):
min_chunk_ind: int | None,
max_chunk_ind: int | None,
user_access_control_list: list[str] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
document_id = replace_invalid_doc_id_characters(document_id)
vespa_chunks = _get_vespa_chunks_by_document_id(
@@ -992,7 +981,8 @@ class VespaIndex(DocumentIndex):
return []
inference_chunks = [
_vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
_vespa_hit_to_inference_chunk(chunk, null_score=True)
for chunk in vespa_chunks
]
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
return inference_chunks
@@ -1005,7 +995,7 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
@@ -1042,7 +1032,7 @@ class VespaIndex(DocumentIndex):
offset: int = 0,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
@@ -1086,7 +1076,7 @@ class VespaIndex(DocumentIndex):
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = _build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)
@@ -1130,7 +1120,7 @@ class VespaIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = _build_vespa_filters(filters, include_hidden=True)
yql = (
VespaIndex.yql_base.format(index_name=self.index_name)

View File

@@ -0,0 +1,8 @@
from enum import Enum
class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
# remove links entirely
STRIP = "strip"
# turn HTML links into markdown links
MARKDOWN = "markdown"

View File

@@ -3,6 +3,7 @@ import json
import os
import re
import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from email.parser import Parser as EmailParser
from pathlib import Path
@@ -16,6 +17,7 @@ import pptx # type: ignore
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
@@ -64,6 +66,16 @@ def check_file_ext_is_valid(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_text_file(file: IO[bytes]) -> bool:
"""
checks if the first 1024 bytes contain only printable or whitespace characters;
if they do, we say it's a plaintext file
"""
raw_data = file.read(1024)
text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
return all(c in text_chars for c in raw_data)
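# Sketch of the heuristic above: the allowed bytes are bell/backspace/tab/
# newline/form-feed/carriage-return/escape plus 0x20-0xFF minus DEL (0x7F);
# anything else in the first 1KB marks the file as binary.
text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
assert all(c in text_chars for c in b"plain text, line two\n")  # treated as text
assert not all(c in text_chars for c in b"\x00\x01\x02 binary header")  # binary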
def detect_encoding(file: IO[bytes]) -> str:
raw_data = file.read(50000)
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
@@ -88,7 +100,7 @@ def load_files_from_zip(
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
zip_metadata = {}
try:
metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
metadata_file_info = zip_file.getinfo(DANSWER_METADATA_FILENAME)
with zip_file.open(metadata_file_info, "r") as metadata_file:
try:
zip_metadata = json.load(metadata_file)
@@ -96,18 +108,19 @@ def load_files_from_zip(
# convert list of dicts to dict of dicts
zip_metadata = {d["filename"]: d for d in zip_metadata}
except json.JSONDecodeError:
logger.warn("Unable to load .danswer_metadata.json")
logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
except KeyError:
logger.info("No .danswer_metadata.json file")
logger.info(f"No {DANSWER_METADATA_FILENAME} file")
for file_info in zip_file.infolist():
with zip_file.open(file_info.filename, "r") as file:
if ignore_dirs and file_info.is_dir():
continue
if ignore_macos_resource_fork_files and is_macos_resource_fork_file(
file_info.filename
):
if (
ignore_macos_resource_fork_files
and is_macos_resource_fork_file(file_info.filename)
) or file_info.filename == DANSWER_METADATA_FILENAME:
continue
yield file_info, file, zip_metadata.get(file_info.filename, {})
@@ -259,37 +272,32 @@ def extract_file_text(
file: IO[Any],
break_on_unprocessable: bool = True,
) -> str:
if not file_name:
return file_io_to_text(file)
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
".docx": docx_to_text,
".pptx": pptx_to_text,
".xlsx": xlsx_to_text,
".eml": eml_to_text,
".epub": epub_to_text,
".html": parse_html_page_basic,
}
extension = get_file_ext(file_name)
if not check_file_ext_is_valid(extension):
def _process_file() -> str:
if file_name:
extension = get_file_ext(file_name)
if check_file_ext_is_valid(extension):
return extension_to_function.get(extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we are familiar with
if is_text_file(file):
return file_io_to_text(file)
raise ValueError("Unknown file extension and unknown text encoding")
try:
return _process_file()
except Exception as e:
if break_on_unprocessable:
raise RuntimeError(f"Unprocessable file type: {file_name}")
else:
logger.warning(f"Unprocessable file type: {file_name}")
return ""
if extension == ".pdf":
return pdf_to_text(file=file)
elif extension == ".docx":
return docx_to_text(file)
elif extension == ".pptx":
return pptx_to_text(file)
elif extension == ".xlsx":
return xlsx_to_text(file)
elif extension == ".eml":
return eml_to_text(file)
elif extension == ".epub":
return epub_to_text(file)
elif extension == ".html":
return parse_html_page_basic(file)
else:
return file_io_to_text(file)
raise RuntimeError(f"Failed to process file: {str(e)}") from e
logger.warning(f"Failed to process file: {str(e)}")
return ""

View File

@@ -5,8 +5,10 @@ from typing import IO
import bs4
from danswer.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
MINTLIFY_UNWANTED = ["sticky", "hidden"]
@@ -32,6 +34,19 @@ def strip_newlines(document: str) -> str:
return re.sub(r"[\n\r]+", " ", document)
def format_element_text(element_text: str, link_href: str | None) -> str:
element_text_no_newlines = strip_newlines(element_text)
if (
not link_href
or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
== HtmlBasedConnectorTransformLinksStrategy.STRIP
):
return element_text_no_newlines
return f"[{element_text_no_newlines}]({link_href})"
def format_document_soup(
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
@@ -49,6 +64,8 @@ def format_document_soup(
verbatim_output = 0
in_table = False
last_added_newline = False
link_href: str | None = None
for e in document.descendants:
verbatim_output -= 1
if isinstance(e, bs4.element.NavigableString):
@@ -71,7 +88,7 @@ def format_document_soup(
content_to_add = (
element_text
if verbatim_output > 0
else strip_newlines(element_text)
else format_element_text(element_text, link_href)
)
# Don't join separate elements without any spacing
@@ -98,7 +115,14 @@ def format_document_soup(
elif in_table:
# don't handle other cases while in table
pass
elif e.name == "a":
href_value = e.get("href", None)
# mostly for typing, having multiple hrefs is not valid HTML
link_href = (
href_value[0] if isinstance(href_value, list) else href_value
)
elif e.name == "/a":
link_href = None
elif e.name in ["p", "div"]:
if not list_element_start:
text += "\n"

View File

@@ -0,0 +1,2 @@
MAX_IN_MEMORY_SIZE = 30 * 1024 * 1024 # 30MB
STANDARD_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB chunks

View File

@@ -5,6 +5,7 @@ from typing import IO
from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.models import PGFileStore
from danswer.db.pg_file_store import create_populate_lobj
from danswer.db.pg_file_store import delete_lobj_by_id
from danswer.db.pg_file_store import delete_pgfilestore_by_file_name
@@ -26,6 +27,7 @@ class FileStore(ABC):
display_name: str | None,
file_origin: FileOrigin,
file_type: str,
file_metadata: dict | None = None,
) -> None:
"""
Save a file to the blob store
@@ -41,12 +43,17 @@ class FileStore(ABC):
raise NotImplementedError
@abstractmethod
def read_file(self, file_name: str, mode: str | None) -> IO:
def read_file(
self, file_name: str, mode: str | None, use_tempfile: bool = False
) -> IO:
"""
Read the content of a given file by the name
Parameters:
- file_name: Name of file to read
- mode: Mode to open the file (e.g. 'b' for binary)
- use_tempfile: Whether to use a temporary file to store the contents
in order to avoid loading the entire file into memory
Returns:
Contents of the file and metadata dict
@@ -73,6 +80,7 @@ class PostgresBackedFileStore(FileStore):
display_name: str | None,
file_origin: FileOrigin,
file_type: str,
file_metadata: dict | None = None,
) -> None:
try:
# The large objects in postgres are saved as special objects can be listed with
@@ -85,20 +93,33 @@ class PostgresBackedFileStore(FileStore):
file_type=file_type,
lobj_oid=obj_id,
db_session=self.db_session,
file_metadata=file_metadata,
)
self.db_session.commit()
except Exception:
self.db_session.rollback()
raise
def read_file(self, file_name: str, mode: str | None = None) -> IO:
def read_file(
self, file_name: str, mode: str | None = None, use_tempfile: bool = False
) -> IO:
file_record = get_pgfilestore_by_file_name(
file_name=file_name, db_session=self.db_session
)
return read_lobj(
lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
lobj_oid=file_record.lobj_oid,
db_session=self.db_session,
mode=mode,
use_tempfile=use_tempfile,
)
def read_file_record(self, file_name: str) -> PGFileStore:
file_record = get_pgfilestore_by_file_name(
file_name=file_name, db_session=self.db_session
)
return file_record
def delete_file(self, file_name: str) -> None:
try:
file_record = get_pgfilestore_by_file_name(

View File

@@ -3,12 +3,16 @@ from collections.abc import Callable
from typing import TYPE_CHECKING
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import CHUNK_OVERLAP
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.constants import TITLE_SEPARATOR
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
@@ -19,6 +23,14 @@ if TYPE_CHECKING:
from transformers import AutoTokenizer # type:ignore
# Overlaps are not supported: we need a clean combination of chunks, and it is
# unclear whether overlaps actually help quality at all
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers, but the general concept is that we don't want the
# title/metadata to overwhelm the actual contents of the chunk
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
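# Worked example (sketch) of the budget these constants impose in
# chunk_document below: metadata may take at most 25% of the chunk's tokens,
# the title/metadata token counts are subtracted from the content budget, and
# if fewer than CHUNK_MIN_CONTENT tokens would remain, prefix/suffix are dropped.
chunk_tok_size, title_tokens, metadata_tokens = 512, 20, 100
assert metadata_tokens < chunk_tok_size * 0.25  # 100 < 128, keep the suffix
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens  # 392
assert content_token_limit > 256  # enough room left for actual content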
logger = setup_logger()
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
@@ -44,6 +56,8 @@ def chunk_large_section(
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
title_prefix: str = "",
metadata_suffix: str = "",
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
@@ -60,30 +74,69 @@ def chunk_large_section(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
blurb=blurb,
-content=chunk_str,
+content=f"{title_prefix}{chunk_str}{metadata_suffix}",
content_summary=chunk_str,
source_links={0: section_link_text},
section_continuation=(chunk_ind != 0),
metadata_suffix=metadata_suffix,
)
for chunk_ind, chunk_str in enumerate(split_texts)
]
return chunks
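Each emitted chunk therefore stores two views of the text: `content` with the affixes baked in for embedding, and `content_summary` with the clean text. A small illustration with hypothetical separator and metadata values:

```python
# Hypothetical values; the real separators come from danswer.configs.constants.
RETURN_SEPARATOR = "\n\n"
title_prefix = "Quarterly Planning Notes" + RETURN_SEPARATOR
chunk_str = "Headcount stays flat while infra spend grows 10%..."
metadata_suffix = RETURN_SEPARATOR + "Metadata:\n\tauthor - jdoe"

content = f"{title_prefix}{chunk_str}{metadata_suffix}"  # what gets embedded
content_summary = chunk_str  # clean text kept for downstream consumers
```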
def _get_metadata_suffix_for_document_index(
metadata: dict[str, str | list[str]]
) -> str:
if not metadata:
return ""
metadata_str = "Metadata:\n"
for key, value in metadata.items():
if key in get_metadata_keys_to_ignore():
continue
value_str = ", ".join(value) if isinstance(value, list) else value
metadata_str += f"\t{key} - {value_str}\n"
return metadata_str.strip()
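For example, assuming neither key is in the ignore list, a document's metadata renders like this:

```python
metadata = {"author": "jdoe", "topics": ["billing", "refunds"]}
print(_get_metadata_suffix_for_document_index(metadata))
# Metadata:
#     author - jdoe
#     topics - billing, refunds
```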
def chunk_document(
document: Document,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
) -> list[DocAwareChunk]:
-title = document.get_title_for_document_index()
-title_prefix = title.replace("\n", " ") + TITLE_SEPARATOR if title else ""
+tokenizer = get_default_tokenizer()
+title = document.get_title_for_document_index()
+title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
title_tokens = len(tokenizer.tokenize(title_prefix))
metadata_suffix = ""
metadata_tokens = 0
if include_metadata:
metadata = _get_metadata_suffix_for_document_index(document.metadata)
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
metadata_suffix = ""
metadata_tokens = 0
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
content_token_limit = chunk_tok_size
title_prefix = ""
metadata_suffix = ""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
-for ind, section in enumerate(document.sections):
-    section_text = title_prefix + section.text if ind == 0 else section.text
+for section in document.sections:
+    section_text = section.text
section_link_text = section.link or ""
section_tok_length = len(tokenizer.tokenize(section_text))
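Tracing the budget computation above with illustrative token counts:

```python
chunk_tok_size = 512   # illustrative embedding context size
title_tokens = 18      # tokens in the (truncated) title prefix
metadata_tokens = 70   # tokens in the rendered metadata suffix

content_token_limit = chunk_tok_size - title_tokens - metadata_tokens  # 424

# Had the remainder been <= CHUNK_MIN_CONTENT (256), the prefix and suffix
# would be cleared and the full 512 tokens given back to the section text.
```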
@@ -92,16 +145,18 @@ def chunk_document(
# Large sections are considered self-contained/unique, so they start a new chunk
# and other sections are not concatenated onto their end
-if section_tok_length > chunk_tok_size:
+if section_tok_length > content_token_limit:
if chunk_text:
chunks.append(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
-content=chunk_text,
+content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
link_offsets = {}
@@ -113,9 +168,11 @@ def chunk_document(
document=document,
start_chunk_id=len(chunks),
tokenizer=tokenizer,
-chunk_size=chunk_tok_size,
+chunk_size=content_token_limit,
chunk_overlap=subsection_overlap,
blurb_size=blurb_size,
title_prefix=title_prefix,
metadata_suffix=metadata_suffix,
)
chunks.extend(large_section_chunks)
continue
@@ -125,7 +182,7 @@ def chunk_document(
current_tok_length
+ len(tokenizer.tokenize(SECTION_SEPARATOR))
+ section_tok_length
-<= chunk_tok_size
+<= content_token_limit
):
chunk_text += (
SECTION_SEPARATOR + section_text if chunk_text else section_text
@@ -137,9 +194,11 @@ def chunk_document(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
-content=chunk_text,
+content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
link_offsets = {0: section_link_text}
@@ -153,9 +212,11 @@ def chunk_document(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
-content=chunk_text,
+content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
return chunks
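A hedged usage sketch of the full pipeline (assumes an already-constructed Document from one of the connectors):

```python
chunks = chunk_document(document)

for chunk in chunks:
    # `content` carries the title prefix and metadata suffix for embedding;
    # `content_summary` is the unadorned section text.
    print(chunk.chunk_id, chunk.section_continuation, len(chunk.content))
```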
@@ -164,6 +225,9 @@ def chunk_document(
def split_chunk_text_into_mini_chunks(
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
"""The minichunks won't all have the title prefix or metadata suffix
It could be a significant percentage of every minichunk so better to not include it
"""
from llama_index.text_splitter import SentenceSplitter
token_count_func = get_default_tokenizer().tokenize
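The function body is cut off above; under the assumption that it mirrors chunk_large_section's use of llama_index's SentenceSplitter (parameter names per the llama_index release this code imports from), the splitting step likely looks roughly like:

```python
# Hedged sketch, not the actual body of split_chunk_text_into_mini_chunks.
from llama_index.text_splitter import SentenceSplitter

from danswer.search.search_nlp_models import get_default_tokenizer


def split_into_mini_chunks_sketch(chunk_text: str, mini_chunk_size: int) -> list[str]:
    splitter = SentenceSplitter(
        tokenizer=get_default_tokenizer().tokenize,
        chunk_size=mini_chunk_size,
        chunk_overlap=0,  # mini-chunks do not overlap
    )
    # Split on sentence boundaries; no title prefix or metadata suffix here.
    return splitter.split_text(chunk_text)
```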

Some files were not shown because too many files have changed in this diff.