Compare commits

...

373 Commits

Author SHA1 Message Date
Yuhong Sun
4293543a6a k 2024-07-20 16:48:05 -07:00
Yuhong Sun
e95bfa0e0b Suffix Test (#1880) 2024-07-20 15:54:55 -07:00
Yuhong Sun
4848b5f1de Suffix Edits (#1878) 2024-07-20 13:59:14 -07:00
Yuhong Sun
7ba5c434fa Missing Comma (#1877) 2024-07-19 22:15:45 -07:00
Yuhong Sun
59bf5ba848 File Connector Metadata (#1876) 2024-07-19 20:45:18 -07:00
Weves
f66c33380c Improve widget README 2024-07-19 20:21:07 -07:00
Weves
115650ce9f Add example widget code 2024-07-19 20:14:52 -07:00
Weves
7aa3602fca Fix black 2024-07-19 18:55:09 -07:00
Weves
864c552a17 Fix UT 2024-07-19 18:55:09 -07:00
Brent Kwok
07b2ed3d8f Fix HTTP 422 error for api_inference_sample.py (#1868) 2024-07-19 18:54:43 -07:00
Yuhong Sun
38290057f2 Search Eval (#1873) 2024-07-19 16:48:58 -07:00
Weves
2344edf158 Change default login time to 7 days 2024-07-19 13:58:50 -07:00
versecafe
86d1804eb0 Add GPT-4o-Mini & fix a missing gpt-4o 2024-07-19 12:10:27 -07:00
pablodanswer
1ebae50d0c minor udpate 2024-07-19 10:53:28 -07:00
Weves
a9fbaa396c Stop building on every PR 2024-07-19 10:21:19 -07:00
pablodanswer
27d5f69427 udpate to headers (#1864) 2024-07-19 08:38:54 -07:00
pablodanswer
5d98421ae8 show "analysis" (#1863) 2024-07-18 18:18:36 -07:00
Kevin Shi
6b561b8ca9 Add config to skip zendesk article labels 2024-07-18 18:00:51 -07:00
pablodanswer
2dc7e64dd7 fix internet search icons / text + assistants tab (#1862) 2024-07-18 16:15:19 -07:00
Yuhong Sun
5230f7e22f Enforce Disable GenAI if set (#1860) 2024-07-18 13:25:55 -07:00
hagen-danswer
a595d43ae3 Fixed deleting toolcall by message 2024-07-18 12:52:28 -07:00
Yuhong Sun
ee561f42ff Cleaner Layout (#1857) 2024-07-18 11:13:16 -07:00
Yuhong Sun
f00b3d76b3 Touchup NoOp (#1856) 2024-07-18 08:44:27 -07:00
Yuhong Sun
e4984153c0 Touchups (#1855) 2024-07-17 23:47:10 -07:00
pablodanswer
87fadb07ea COMPLETE USER EXPERIENCE OVERHAUL (#1822) 2024-07-17 19:44:21 -07:00
pablodanswer
2b07c102f9 fix discourse connector rate limiting + topic fetching (#1820) 2024-07-17 14:57:40 -07:00
hagen-danswer
e93de602c3 Use SHA instead of branch and save more data (#1850) 2024-07-17 14:56:24 -07:00
hagen-danswer
1c77395503 Fixed llm_indices from document search api (#1853) 2024-07-17 14:52:49 -07:00
Victorivus
cdf6089b3e Fix bug XML files in chat (#1804) 2024-07-17 08:09:40 -07:00
pablodanswer
d01f46af2b fix search doc bug (#1851) 2024-07-16 15:27:04 -07:00
hagen-danswer
b83f435bb0 Catch dropped eval questions and added multiprocessing (#1849) 2024-07-16 12:33:02 -07:00
hagen-danswer
25b3dacaba Seperated model caching volumes (#1845) 2024-07-15 15:32:04 -07:00
hagen-danswer
a1e638a73d Improved eval logging and stability (#1843) 2024-07-15 14:58:45 -07:00
Yuhong Sun
bd1e0c5969 Add Enum File (#1842) 2024-07-15 09:13:27 -07:00
Yuhong Sun
4d295ab97d Model Server Logging (#1839) 2024-07-15 09:00:27 -07:00
Weves
6fe3eeaa48 Fix model serer startup 2024-07-14 23:33:58 -07:00
Chris Weaver
078d5defbb Update Slack link in README.md 2024-07-14 16:50:48 -07:00
Weves
0d52e99bd4 Improve confluence rate limiting 2024-07-14 16:40:45 -07:00
hagen-danswer
1b864a00e4 Added support for multiple Eval Pipeline UIs (#1830) 2024-07-14 15:16:20 -07:00
Weves
dae4f6a0bd Fix latency caused by large numbers of tags 2024-07-14 14:21:07 -07:00
Yuhong Sun
f63d0ca3ad Title Truncation Logic (#1828) 2024-07-14 13:54:36 -07:00
Yuhong Sun
da31da33e7 Fix Title for docs without (#1827) 2024-07-14 13:51:11 -07:00
Yuhong Sun
56b175f597 Fix Sitemap Robo (#1826) 2024-07-14 13:29:26 -07:00
Zoltan Szabo
1b311d092e Try to find the sitemap for a given site (#1538) 2024-07-14 13:24:10 -07:00
Moshe Zada
6ee1292757 Fix semantic id for web pdfs (#1823) 2024-07-14 11:38:11 -07:00
Yuhong Sun
017af052be Global Tokenizer Fix (#1825) 2024-07-14 11:37:10 -07:00
pablodanswer
e7f81d1688 add third party embedding models (#1818) 2024-07-14 10:19:53 -07:00
Weves
b6bd818e60 Fix user groups page when a persona is deleted 2024-07-13 15:35:50 -07:00
hagen-danswer
36da2e4b27 Fixed slack groups (#1814)
* Simplified slackbot response groups and fixed need more help bug

* mypy fixes

* added exceptions for the couldnt find passthrough arrays
2024-07-13 22:34:35 +00:00
pablodanswer
c7af6a4601 add new standard answer test endpoint (#1789) 2024-07-12 10:06:30 -07:00
Yuhong Sun
e90c66c1b6 Include Titles in Chunks (#1817) 2024-07-12 09:42:24 -07:00
hagen-danswer
8c312482c1 fixed id retrieval from zip metadata (#1813) 2024-07-11 20:38:12 -07:00
Weves
e50820e65e Remove Internet 'Connector' that mistakenly appears on the Add Connector page 2024-07-11 18:00:59 -07:00
hagen-danswer
991ee79e47 some qol improvements for search pipeline (#1809) 2024-07-11 17:42:11 -07:00
hagen-danswer
3e645a510e Fix slack error logging (#1800) 2024-07-11 08:31:48 -07:00
Yuhong Sun
08c6e821e7 Merge Sections Logic (#1801) 2024-07-10 20:14:02 -07:00
hagen-danswer
47a550221f slackbot doesnt respond without citations/quotes (#1798)
* slackbot doesnt respond without citations/quotes

fixed logical issues

fixed dict logic

* added slackbot shim for the llm source/time feature

* mypy fixes

* slackbot doesnt respond without citations/quotes

fixed logical issues

fixed dict logic

* Update handle_regular_answer.py

* added bypass_filter check

* final fixes
2024-07-11 00:18:26 +00:00
Weves
511f619212 Add content to /document-search response 2024-07-10 15:44:58 -07:00
Varun Gaur
6c51f001dc Confluence Connector to Sync Child pages only (#1629)
---------

Co-authored-by: Varun Gaur <vgaur@roku.com>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-07-10 14:17:03 -07:00
pablodanswer
09a11b5e1a Fix citations + unit tests (#1760) 2024-07-10 10:05:20 -07:00
pablodanswer
aa0f7abdac add basic table wrapping (#1791) 2024-07-09 19:14:41 -07:00
Yuhong Sun
7c8f8dba17 Break the Danswer LLM logging from LiteLLM Verbose (#1795) 2024-07-09 18:18:29 -07:00
Yuhong Sun
39982e5fdc Info propagating to allow Chunk Merging (#1794) 2024-07-09 18:15:07 -07:00
pablodanswer
5e0de111f9 fix wrapping in error hover connector (#1790) 2024-07-09 11:54:35 -07:00
pablodanswer
727d80f168 fix gpt-4o image issue (#1786) 2024-07-08 23:07:53 +00:00
rashad-danswer
146f85936b Internet Search Tool (#1666)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-06 18:01:24 -07:00
Chris Weaver
e06f8a0a4b Standard Answers (#1753)
---------

Co-authored-by: druhinsgoel <druhin@danswer.ai>
2024-07-06 16:11:11 -07:00
Yuhong Sun
f0888f2f61 Eval Script Incremental Write (#1784) 2024-07-06 15:43:40 -07:00
Yuhong Sun
d35d7ee833 Evaluation Pipeline Touchup (#1783) 2024-07-06 13:17:05 -07:00
Yuhong Sun
c5bb3fde94 Ignore Eval Files (#1782) 2024-07-06 12:15:03 -07:00
Yuhong Sun
79190030a5 New Env File for Eval (#1781) 2024-07-06 12:07:31 -07:00
Yuhong Sun
8e8f262ed3 Docker Compose Eval Pipeline Cleanup (#1780) 2024-07-06 12:04:57 -07:00
hagen-danswer
ac14369716 Added search quality testing pipeline (#1774) 2024-07-06 11:51:50 -07:00
Weves
de4d8e9a65 Fix shared chats 2024-07-04 11:41:16 -07:00
hagen-danswer
0b384c5b34 fixed salesforce url generation (#1777) 2024-07-04 10:43:21 -07:00
Weves
fa049f4f98 Add UI support for github configs 2024-07-03 17:37:59 -07:00
pablodanswer
72d6a0ef71 minor updates to assistant UI (#1771) 2024-07-03 18:28:25 +00:00
pablodanswer
ae4e643266 Update Assistants Creation UI (#1714)
* slide up "Tools"

* rework assistants page

* update layout

* reorg complete

- pending: useful header text?

* add tooltips

* alter organizational structure

* rm shadcn

* rm dependencies

* revalidate dependencies

* restore

* update component structure

* [s] format

* rm package json

* add package-lock.json [s]

* collapsible

* naming + width

* formatting

* formatting

* updated user flow

- Fix error/detail messages
- Fix tooltip delay
- Fix icons

* 1 -> 2

* naming fixes

* ran pretty

* fix build issue?

* web build issues?
2024-07-03 17:11:14 +00:00
hagen-danswer
a7da07afc0 allowed arbitrary types to handle the sqlalchemy datatype (#1758)
* allowed arbitrary types to handle the sqlalchemy datatype

* changed persona_upsert to take in ids instead of objects
2024-07-03 07:10:57 +00:00
Weves
7f1bb67e52 Pass through API base to ImageGenerationTool 2024-07-02 23:31:04 -07:00
Weves
982b1b0c49 Add litellm.set_verbose support 2024-07-02 23:22:17 -07:00
Daniel Naber
2db128fb36 Notion date filter fix (#1755)
* fix filter logic

* make comparison better readable
2024-07-02 15:39:35 -07:00
Christoph Petzold
3ebac6256f Fix "cannot access local variable" for bot direct messages (#1737)
* Update handle_message.py

* Update handle_message.py

* Update handle_message.py

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-07-02 15:36:23 -07:00
Weves
1a3ec59610 Fix build caused by bad seeding config 2024-07-01 23:41:43 -07:00
hagen-danswer
581cb827bb added settings and persona seeding options (#1742)
* added settings and persona seeding options

* updated recency_bias

* changed variable type

* another fix

* Update seeding.py

* fixed mypy

* push
2024-07-01 22:22:17 +00:00
Weves
393b3c9343 Fix misc chat bugs 2024-06-30 18:08:03 -07:00
Weves
2035e9f39c Fix docker build 2024-06-30 14:34:47 -07:00
Weves
52c3a5e9d2 Fix slackbot citation images 2024-06-30 13:42:56 -07:00
pablodanswer
3e45a41617 Bugfix/scroll (#1748) 2024-06-30 12:58:57 -07:00
Weves
415960564d Fix fast models 2024-06-29 15:19:09 -07:00
pablodanswer
ed550986a6 Feature/assistants (#1581)
* include alternate assisstant

- migrate models
- migrate db

* functional alternate assistant selection

* refactor chat components for persona API

* functional assistants api

* add full functionality- assistants

* add functional assistants dropdown handler

* refactor assistants for full compatability

- hooks
- track the live assistant for edge cases
- UI updates

* add assistant UI features

- Autotab
- Arrow selection
- Icons
- Proper @ detection
- Info Popup

prune unnecessary comments

* functional search toggling for assistants

* add functional cross-page assistants

rebase with main

* add proper interactivity for edge cases

- click outside of input / text box
- "force search" assistant consistency

* refactor alt assistant consistency

* update alembic versions

* rebased

* undo formatting changes

* additional formatting

* current processing

* merge fixes

* formatting

* colors

* 2 -> 1

* 1 -> 2

---------

Co-authored-by: “Pablo <“pablo@danswer.ai”>
2024-06-29 00:18:39 +00:00
hagen-danswer
60dd77393d Disallowed simultaneous pruning jobs (#1704)
* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* Disallowed simultaneous pruning jobs

* added rate limiting and env vars

* i hope this is how you use decorators

* nonsense

* cleaned up names, added config

* renamed other utils

* Update celery_utils.py

* reverted changes
2024-06-28 23:26:00 +00:00
hagen-danswer
3fe5313b02 renamed alembic table (#1741) 2024-06-28 22:54:19 +00:00
hagen-danswer
bd0925611a Added TTL to EE Celery tasks (#1713)
* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* default value to None

* made all deletions manual

* added fix

* Use tremor buttons now

* removed words

* Update 23957775e5f5_remove_feedback_foreignkey_constraint.py

* fixed alembic version
2024-06-28 22:13:47 +00:00
pablodanswer
de6d040349 add boto3 typing to default requirements (#1740) 2024-06-28 13:06:59 -07:00
Chris Weaver
38da3128d8 Pass headers into image generation (#1739) 2024-06-28 12:33:53 -07:00
Chris Weaver
e47da0d688 Small readme improvement (#1735) 2024-06-27 19:12:24 -07:00
rkuo-danswer
2c0e0c5f11 Merge pull request #1731 from danswer-ai/feature/merge-queue-workflows
add workflows that automate docker builds against merge group events
2024-06-27 18:25:05 -07:00
Richard Kuo (Danswer)
29d57f6354 remove obsolete comment 2024-06-27 17:39:17 -07:00
rkuo-danswer
369e607631 Fixes DAN-189 (safari bug in admin). Removed td/absolute positioning behavior which is unde… (#1718) 2024-06-27 17:15:42 -07:00
pablodanswer
f03f97307f Blob Storage (#1705)
S3 + OCI + Google Cloud Storage + R2
---------

Co-authored-by: Art Matsak <5328078+artmatsak@users.noreply.github.com>
2024-06-27 17:12:20 -07:00
Weves
145cdb69b7 Remove duplicate tool check 2024-06-27 17:05:50 -07:00
pablodanswer
9310a8edc2 Feature/scroll (#1694)
---------

Co-authored-by: “Pablo <“pablo@danswer.ai”>
2024-06-27 16:40:23 -07:00
Yuhong Sun
2140f80891 Tidy up Actions ported from EE (#1732) 2024-06-27 16:20:34 -07:00
Richard Kuo (Danswer)
52dab23295 add workflows that automate docker builds against merge group events 2024-06-27 14:31:39 -07:00
Weves
91c9b2eb42 Add more logging for num workers in simple job client 2024-06-27 11:49:08 -07:00
Weves
5764cdd469 Use FiEdit2 as the standard edit icon 2024-06-27 11:32:47 -07:00
Weves
8fea6d7f64 Fix share for insecure: 2024-06-27 11:19:04 -07:00
pablodanswer
5324b15397 Chat overflow (#1723) 2024-06-27 11:02:11 -07:00
Yuhong Sun
8be42a5f98 Touchup for Multilingual Users (#1725) 2024-06-26 22:44:06 -07:00
Weves
062dc98719 Fix search tool 2024-06-26 21:28:16 -07:00
pablodanswer
43557f738b add copy-paste images (#1722) 2024-06-26 20:12:34 -07:00
Weves
b5aa7370a2 Make seeded model default 2024-06-26 16:03:24 -07:00
Weves
4ba6e45128 Small template fix 2024-06-26 13:58:30 -07:00
pablodanswer
d6e5a98a22 Minor Update to UI (#1692)
New sidebar / chatbar color, hidable right-panel, and many more small tweaks.

---------

Co-authored-by: pablodanswer <“pablo@danswer.ai”>
2024-06-26 13:54:41 -07:00
Yuhong Sun
20c4cdbdda Catch LLM Generation Failure (#1712) 2024-06-26 10:44:22 -07:00
Yuhong Sun
0d814939ee Bugfix for Selected Doc when the message it is selected from failed (#1711) 2024-06-26 10:32:30 -07:00
Yuhong Sun
7d2b0ffcc5 Developer Env Setup (#1710) 2024-06-26 09:45:19 -07:00
hagen-danswer
8c6cd661f5 Ignore messages from Slack's official bot (#1703)
* Ignore messages from Slack's official bot

* re-added message filter
2024-06-25 20:33:36 -07:00
Weves
5d552705aa Change EE environment variable name 2024-06-25 16:59:09 -07:00
Weves
1ee8ee9e8b Prepare EE to merge with MIT 2024-06-25 15:07:56 -07:00
Chris Weaver
f0b2b57d81 Usage reports (#118)
---------

Co-authored-by: amohamdy99 <a.mohamdy99@gmail.com>
2024-06-25 15:07:56 -07:00
hagen-danswer
5c12a3e872 brought out the UsersResponse interface (#119) 2024-06-25 15:07:56 -07:00
Weves
3af81ca96b Fix seed config when left empty 2024-06-25 15:07:56 -07:00
Weves
f55e5415bb Add empty assets folder 2024-06-25 15:07:56 -07:00
Weves
3d434c2c9e Fix persona access for answer-with-quote API 2024-06-25 15:07:56 -07:00
pablodanswer
90ec156791 formatting 2024-06-25 15:07:56 -07:00
pablodanswer
8ba48e24a6 minor build fix 2024-06-25 15:07:56 -07:00
pablodanswer
e34bcbbd06 Add persistent name and logo seeding (#107) 2024-06-25 15:07:56 -07:00
pablodanswer
db319168f8 stronger wording 2024-06-25 15:07:56 -07:00
pablodanswer
010ce5395f Minor/ee optional branding (#105) 2024-06-25 15:07:56 -07:00
rashad-danswer
98a58337a7 Query history speed fix (#109) 2024-06-25 15:07:56 -07:00
Weves
733d4e666b Add support for private file connectors 2024-06-25 15:07:56 -07:00
Weves
2937fe9e7d Fix backend build 2024-06-25 15:07:56 -07:00
Weves
457527ac86 Try different runner groups for each build 2024-06-25 15:07:56 -07:00
Weves
7cc51376f2 Allow basic seeding of Danswer via env variable 2024-06-25 15:07:56 -07:00
Weves
7278d45552 Fix rebase issue 2024-06-25 15:07:56 -07:00
Yuhong Sun
1c343bbee7 Enable Dedup Flag for Doc Search Endpoint 2024-06-25 15:07:56 -07:00
Weves
bdcfb39724 Add whitelabeled name to login page 2024-06-25 15:07:56 -07:00
Weves
694d20ea8f Fix user groups issue from rebase 2024-06-25 15:07:56 -07:00
Weves
45402d0755 Add back custom logo/name to sidebar header 2024-06-25 15:07:56 -07:00
Weves
69740ba3d5 Fix rebase issue 2024-06-25 15:07:56 -07:00
Yuhong Sun
6162283beb Fix formatting issues (#93) 2024-06-25 15:07:56 -07:00
Yuhong Sun
44284f7912 Fix Rebase Issues (#92) 2024-06-25 15:07:56 -07:00
Weves
775ca5787b Move web build to a matrix build 2024-06-25 15:07:56 -07:00
Weves
c6e49a3034 Don't get duplicate docs during user group syncing 2024-06-25 15:07:56 -07:00
Weves
9c8cfd9175 Fix mypy 2024-06-25 15:07:56 -07:00
Weves
fc3ed76d12 Add pagination to user group syncing 2024-06-25 15:07:56 -07:00
Weves
a2597d5f21 Fix rebase issue with dev compose file 2024-06-25 15:07:56 -07:00
Yuhong Sun
af588461d2 Enable Encryption 2024-06-25 15:07:56 -07:00
Weves
460e61b3a7 Fix document lock acquisition for user group sync 2024-06-25 15:07:56 -07:00
Weves
c631ac0c3a Change secret name 2024-06-25 15:07:56 -07:00
Yuhong Sun
10be91a8cc Track Slack questions Autoresolved (#86) 2024-06-25 15:07:56 -07:00
Weves
eadad34a77 Fix /send-message-simple-api endpoint 2024-06-25 15:07:56 -07:00
Weves
b19d88a151 Fix rebase issue with file_store 2024-06-25 15:07:56 -07:00
Weves
e33b469915 Remove unused Chat.tsx file 2024-06-25 15:07:56 -07:00
Weves
719fc06604 Fix rebase issue with UI-based LLM selection 2024-06-25 15:07:56 -07:00
Alan Hagedorn
d7a704c0d9 Token Rate Limiting
WIP

Cleanup 🧹

Remove existing rate limiting logic

Cleanup 🧼

Undo nit

Cleanup 🧽

Move db constants (avoids circular import)

WIP

WIP

Cleanup

Lint

Resolve alembic conflict

Fix mypy

Add backfill to migration

Update comment

Make unauthenticated users still adhere to global limits

Use Depends

Remove enum from table

Update migration error handling + deletion

Address verbal feedback, cleanup urls, minor nits
2024-06-25 15:07:56 -07:00
Yuhong Sun
7a408749cf Fix Web Compile Issue (#81) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d9acd03a85 Query History Include Feedback Text (#80) 2024-06-25 15:07:56 -07:00
Yuhong Sun
af94c092e7 Reduce sync jobs batch size (#79) 2024-06-25 15:07:56 -07:00
Yuhong Sun
f55a4ef9bd Remove Nested Session (#78) 2024-06-25 15:07:56 -07:00
Yuhong Sun
6c6e33e001 Allow Empty API Names (#77) 2024-06-25 15:07:56 -07:00
Yuhong Sun
336c046e5d Better Naming for API Keys (#76) 2024-06-25 15:07:56 -07:00
Weves
9a9b89f073 Fix rebase issue with public assistants 2024-06-25 15:07:56 -07:00
Weves
89fac98534 Fix ee redirect 2024-06-25 15:07:56 -07:00
Weves
65b65518de Add back ChatBanner 2024-06-25 15:07:56 -07:00
Yuhong Sun
0c827d1e6c Permission Sync Framework (#44) 2024-06-25 15:07:56 -07:00
Weves
1984f2c1ca Add automated auth checks for ee 2024-06-25 15:07:56 -07:00
Yuhong Sun
50f006557f Add message id to simple message endpoint (#69) 2024-06-25 15:07:56 -07:00
Yuhong Sun
c00bd44bcc Add Chunk Context options for EE APIs (#68) 2024-06-25 15:07:56 -07:00
Yuhong Sun
680aca68e5 Make EE containers public changes (#67) 2024-06-25 15:07:56 -07:00
Weves
22a2f86fb9 FE build fix 2024-06-25 15:07:56 -07:00
Weves
c055dc1535 Add custom analytics script 2024-06-25 15:07:56 -07:00
Alan Hagedorn
81e9880d9d Add names to API Keys (#63) 2024-06-25 15:07:56 -07:00
Weves
3466f6d3a4 Custom banner 2024-06-25 15:07:56 -07:00
Weves
91cf45165f Small fixes + adding 'Powered by Danswer' 2024-06-25 15:07:56 -07:00
Weves
ee2a5bbf49 Add custom-styling ability via themes 2024-06-25 15:07:56 -07:00
Weves
153007c57c Whitelableing for Logo / Name via Admin panel 2024-06-25 15:07:56 -07:00
Yuhong Sun
fa8cc10063 Allow Optional Rerank in APIs (#60) 2024-06-25 15:07:56 -07:00
Yuhong Sun
2c3ba5f021 Include User in Query Export (#59) 2024-06-25 15:07:56 -07:00
Yuhong Sun
e3ef620094 Query History Now Handles Old Messages (#58) 2024-06-25 15:07:56 -07:00
Yuhong Sun
40369e0538 Formatter (#57) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d6c5c65b51 Fix Query History (#56) 2024-06-25 15:07:56 -07:00
Yuhong Sun
7b16cb9562 Rebase search changes to EE APIs (#55) 2024-06-25 15:07:56 -07:00
Weves
ef4f06a375 Fix SAML for /manage/me 2024-06-25 15:07:56 -07:00
Chris Weaver
17cc262f5d Private personas doc sets (#52)
Private Personas and Document Sets

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-25 15:07:56 -07:00
Yuhong Sun
680482bd06 Metadata filter for document search API (#53) 2024-06-25 15:07:56 -07:00
Weves
64874d2737 Change runner for backend 2024-06-25 15:07:56 -07:00
Weves
00ade322f1 Fix small answer-with-quote bug 2024-06-25 15:07:56 -07:00
Weves
eab5d054d5 Add env variable to control hash rounds 2024-06-25 15:07:56 -07:00
Weves
a09d60d7d0 Fix user group deletion bug 2024-06-25 15:07:56 -07:00
Yuhong Sun
f17dc52b37 One Shot API No Stream (#41) 2024-06-25 15:07:56 -07:00
Yuhong Sun
c1862e961b Simple API No Longer Require Specify Prompt (#40) 2024-06-25 15:07:56 -07:00
Yuhong Sun
6b46a71cb5 Fix Empty Chat for API (#39) 2024-06-25 15:07:56 -07:00
Yuhong Sun
9ae3a4af7f Basic Chat API (#38) 2024-06-25 15:07:56 -07:00
Weves
328b96c9ff Make public/not-public selector prettier 2024-06-25 15:07:56 -07:00
Yuhong Sun
bac34a47b2 Embedding Model Swap Changes (#35) 2024-06-25 15:07:56 -07:00
Weves
15934ee268 Fix limit/offset for document-search endpoint 2024-06-25 15:07:56 -07:00
Weves
fe975c3357 Add global prefix to EE endpoints 2024-06-25 15:07:56 -07:00
Weves
8bf483904d Fix users page with API keys + add spinner on key creation 2024-06-25 15:07:56 -07:00
Yuhong Sun
db338bfddf Introduce EE only Backend APIs (#29) 2024-06-25 15:07:56 -07:00
Weves
ae02a5199a Add API key generation in the UI + allow it to be used across all endpoints 2024-06-25 15:07:56 -07:00
Yuhong Sun
4b44073d9a CVEs (#26) 2024-06-25 15:07:56 -07:00
Weves
ce36530c79 Fix viewing other users' chat histories in query history 2024-06-25 15:07:56 -07:00
Weves
39d69838c5 Make query history fetch client-side 2024-06-25 15:07:56 -07:00
Weves
e11f0f6202 Fix /chat-session-history/{chat_session_id} endpoint when auth is enabled 2024-06-25 15:07:56 -07:00
Weves
ce870ff577 Re-style user group pages 2024-06-25 15:07:56 -07:00
Weves
a52711967f Fix analytics + query history 2024-06-25 15:07:56 -07:00
Weves
67a4eb6f6f Fix frontend typing rebase issue 2024-06-25 15:07:56 -07:00
Weves
9599388db8 Fix sidebar typo 2024-06-25 15:07:56 -07:00
Weves
f82ae158ea Mark indexing jobs as ee when running ee supervisord 2024-06-25 15:07:56 -07:00
Weves
670de6c00d Add new env variable to EE supervisord 2024-06-25 15:07:56 -07:00
Yuhong Sun
56c52bddff Fix missing supervisord change from Danswer MIT (#18) 2024-06-25 15:07:56 -07:00
Chris Weaver
3984350ff9 Improvements to Query History (#17)
* Add option to download query-history as a CSV

* Add user email + more complete timestamp
2024-06-25 15:07:56 -07:00
Weves
f799d9aa11 Fix EE import 2024-06-25 15:07:56 -07:00
Yuhong Sun
529f2c8c2d Danswer EE Version Text (#12) 2024-06-25 15:07:56 -07:00
Yuhong Sun
b4683dc841 Fix Rebase Issue 2024-06-25 15:07:56 -07:00
Chris Weaver
db8ce61ff4 Fix group prefix (#11) 2024-06-25 15:07:56 -07:00
Yuhong Sun
d016e8335e Update default SAML config location (#10) 2024-06-25 15:07:56 -07:00
Yuhong Sun
0c295d1de5 Enable EE features for no-letsencrypt deployment (#9) 2024-06-25 15:07:56 -07:00
Chris Weaver
e9f273d99a Admin Analytics/Query History dashboards (#6) 2024-06-25 15:07:56 -07:00
Weves
428f5edd21 Update ee supervisord 2024-06-25 15:07:56 -07:00
Weves
50170cc97e Move user group syncing to Celery Beat 2024-06-25 15:07:56 -07:00
Chris Weaver
7503f8f37b Add User Groups (a.k.a. RBAC) (#4) 2024-06-25 15:07:56 -07:00
Yuhong Sun
92de6acc6f Initial EE features (#3) 2024-06-25 15:07:56 -07:00
mattboret
65d5808ea7 Confluence: add pages labels indexation (#1635)
* Confluence: add pages labels indexation

* changed the default and fixed the dict building

* Update app_configs.py

* Update connector.py

---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-25 10:41:29 -07:00
Yuhong Sun
061dab7f37 Touchup (#1702) 2024-06-25 10:34:03 -07:00
hagen-danswer
e65d9e155d fixed confluence breaking on unknown filetypes (#1698) 2024-06-25 10:19:01 -07:00
hagen-danswer
50f799edf4 Merge pull request #1672 from danswer-ai/add-groups-to-slack-bot-responses
add slack groups to user response list
2024-06-24 15:10:12 -07:00
pablodanswer
c1d8f6cb66 add basic 403 support for healthcheck (#1689) 2024-06-23 11:19:46 -07:00
pablodanswer
6c71bc05ea modify script deletion name (#1690) 2024-06-23 08:29:37 -07:00
Yuhong Sun
123ec4342a Relari (#1687)
Also includes some bugfixes
2024-06-22 18:52:48 -07:00
pablodanswer
7253316b9e Add script for forced connector deletion (#1683) 2024-06-22 17:15:25 -07:00
Weves
4ae924662c Add migration for usage reports 2024-06-22 16:21:41 -07:00
Yuhong Sun
094eea2742 Discourse Edge Case (#1685) 2024-06-22 15:17:33 -07:00
pablodanswer
8178d536b4 Add functional thread modification endpoints (#1668)
Makes it so if you change which LLM you are using in a given ChatSession, that is persisted and sticks around if you reload the page / come back to the ChatSession later
2024-06-21 18:10:30 -07:00
Yuhong Sun
5cafc96cae Enable Internet Search for Deployment Options (#1684) 2024-06-21 17:38:49 -07:00
Weves
3e39a921b0 Fix image generation output 2024-06-21 15:41:55 -07:00
Weves
98b2507045 Improve persona access 2024-06-21 13:35:14 -07:00
hagen-danswer
3dfe17a54d google drive ignores shortcut filetypes now 2024-06-20 18:50:12 -07:00
hagen-danswer
b4675082b1 Update handle_message.py 2024-06-20 21:16:26 -04:00
hagen-danswer
287a706e89 combined the input fields 2024-06-20 11:48:14 -07:00
Kevin Shi
ba58208a85 Transform HTML links to markdown behind config option (#1671) 2024-06-20 10:43:15 -07:00
hagen-danswer
694e9e8679 finished first draft 2024-06-19 22:11:33 -07:00
pablodanswer
9e30ec1f1f hide popup for non admin + if search is disabled 2024-06-19 13:08:24 -07:00
pablodanswer
1b56c75527 [minor] proper assistant line length 2024-06-19 12:14:42 -07:00
Weves
b07fdbf1d1 Cleanup user management 2024-06-18 14:37:47 -07:00
hagen-danswer
54c2547d89 Add connector document pruning task (#1652) 2024-06-18 14:26:12 -07:00
Liam Norris
58b5e25c97 User Management: Invite, Deactivate, Search, & Paginate (#1631) 2024-06-18 11:28:47 -07:00
hagen-danswer
4e15ba78d5 replicated drive fix for gmail connector (#1658) 2024-06-18 08:47:06 -07:00
Yuhong Sun
c798ade127 Code for ease of eval (#1656) 2024-06-17 20:32:12 -07:00
hagen-danswer
93cc5a9e77 improved salesforce description text (#1655) 2024-06-17 20:14:34 -07:00
Weves
7746375bfd Custom tools 2024-06-17 15:12:50 -07:00
Weves
c6d094b2ee Fix google drive connector page refresh 2024-06-16 14:56:36 -07:00
hagen-danswer
7a855192c3 added google slides format to gdrive connector (#1645) 2024-06-16 13:59:20 -07:00
hagen-danswer
c3577cf346 salesforce qol changes (#1644) 2024-06-16 11:12:37 -07:00
hagen-danswer
722a1dd919 Add salesforce connector (#1621) 2024-06-16 10:04:44 -07:00
hagen-danswer
e4999266ca added azure models to vision capable list (#1638) 2024-06-16 08:38:30 -07:00
Weves
f294dba095 Fix google drive page 2024-06-15 16:04:21 -07:00
Weves
03105ad551 Fix bypass_acl support for Slack bot 2024-06-14 16:57:06 -07:00
hagen-danswer
4b0ff95b26 added pptx to drive reader (#1634) 2024-06-13 22:50:28 -07:00
Vikas Neha Ojha
ff06d62acf Added ClickUp Connector (#1521)
* Added connector for clickup

* Fixed mypy issues

* Fallback to description if markdown is not available

* Added extra information in metadata, and support to index comments

* Fixes for fields parsing

* updated fetcher to errorHandlingFetcher

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-06-13 17:49:31 -07:00
Yuhong Sun
26fee36ed4 Catch Slack Greeting Msgs Generic (#1633) 2024-06-13 09:41:01 -07:00
Scott Davidson
428439447e Fix Helm vespa resource limits (#1606)
Co-authored-by: sd109 <scott@stackhpc.com>
2024-06-11 21:01:26 -07:00
hagen-danswer
e8cfbc1dd8 added a check for empty URL list in web connector (#1573)
* added a check for empty URL list in web connector

* added raise condition for improper sitemap designation
2024-06-11 18:26:44 -07:00
hagen-danswer
486b0ecb31 Confluence: Add page attachments indexing (#1617)
* Confluence: Add page attachments indexing

* used the centralized file processing to extract file content

* flipped input order for extract_file_text

* added bytes support for pdf converter

* brought out the io.BytesIO to the confluence connector

---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2024-06-11 18:23:13 -07:00
dependabot[bot]
8c324f8f01 Bump requests from 2.31.0 to 2.32.2 in /backend/requirements (#1509)
Bumps [requests](https://github.com/psf/requests) from 2.31.0 to 2.32.2.
- [Release notes](https://github.com/psf/requests/releases)
- [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md)
- [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2)

---
updated-dependencies:
- dependency-name: requests
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-11 18:20:08 -07:00
Varun Gaur
5a577f9a00 Added a Logic to Index entire Gitlab Project (#1586)
* Changes for Gitlab connector

* Changes to Rebase from Main

* Changes to Rebase from Main

* Changes to Rebase from Main

* Changes to Rebase from Main

* made indexing code files a config setting

* Update app_configs.py

created env variable

* Update app_configs.py

added false

---------

Co-authored-by: Varun Gaur <vgaur@roku.com>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-06-11 18:18:14 -07:00
Weves
e6d5b95b4a Fix web build 2024-06-11 15:06:55 -07:00
Weves
cc0320b50a Apply passthrough headers for chat renaming 2024-06-10 18:24:46 -07:00
Yuhong Sun
36afa9370f Vespa remove apostrophe in URLs (#1618) 2024-06-10 17:19:47 -07:00
Yuhong Sun
7c9d037b7c Env Template Update (#1615) 2024-06-10 14:24:17 -07:00
Yuhong Sun
a2065f018a Environment Template for VSCode/Cursor (#1614) 2024-06-10 13:36:20 -07:00
Weves
b723627e0c Ability to pass through headers to LLM call 2024-06-10 13:18:31 -07:00
hagen-danswer
180b592afe added html support to file text extractor (#1611) 2024-06-10 12:46:05 -07:00
Weves
e8306b0fa5 Fix web build 2024-06-10 11:40:49 -07:00
Weves
64ee5ffff5 Fix slack bot creation with document sets 2024-06-10 11:29:44 -07:00
hagen-danswer
ead6a851cc Merge pull request #1144 from hagen6835/add-teams-connector
added teams connector
2024-06-10 13:25:54 -04:00
hagen-danswer
73575f22d8 prettier 2024-06-10 09:44:20 -07:00
hagen-danswer
be5dd3eefb final revisions fr 2024-06-10 09:32:39 -07:00
hagen-danswer
f18aa2368e Merge pull request #1601 from danswer-ai/prune-model-list
chatpage now checks for llm override for image uploads
2024-06-09 20:17:13 -04:00
hagen-danswer
3ec559ade2 added null inputs for other usages 2024-06-09 17:11:06 -07:00
hagen-danswer
4d0794f4f5 chatpage now checks for llm override for image uploads 2024-06-09 17:05:41 -07:00
hagen-danswer
64a042b94d cleaned up sharepoint connector (#1599)
* cleaned up sharepoint connector

* additional cleanup

* fixed dropbox connector string
2024-06-09 12:15:52 -07:00
Yuhong Sun
fa3a3d348c Precommit Fixes (#1596) 2024-06-09 00:44:36 -07:00
mattboret
a0e10ac9c2 Slack: add support to rephrase user message (#1528)
* Slack: add support to rephrase user message

* fix: handle rephrase error

* Update listener.py

---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-08 18:11:32 -07:00
Art Matsak
e1ece4a27a Fix Helm chart run failures due to low resources (#1536)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-06-08 18:05:44 -07:00
Bijay Regmi
260149b35a Feature/support all gpus (#1515)
* add `count` explicitely to ensure backwards compat

* fix main document URL
2024-06-08 17:45:26 -07:00
Alexander L
5f2737f9ee Add OAuth env vars to API Server deployment (#1500)
* Add oauth vars to the kubernetes api server deployment

* Add example google oauth id and secret to kubernetes secrets
2024-06-08 17:43:36 -07:00
Moncho Pena
e1d8b88318 Changed the vespa podLabels to work properly (#1594)
Vespa is not working because this configuration

As you can see in this Issue https://github.com/unoplat/vespa-helm-charts/issues/20

You have to use this podLabels to be accord with the other configuration.

vespa:
  podLabels:
    app: vespa
    app.kubernetes.io/instance: danswer <-------------
2024-06-08 17:40:27 -07:00
Joe Shindelar
fc5337d4db Set ignore_danswer_metadata=False when calling read_text_file in file connector. (#1501)
Fixes an issue where metadata specified in a #DANSWER_METADATA line in a file read by the file connector is ignored.
2024-06-08 17:39:18 -07:00
hagen-danswer
bd9335e832 disabled reindexing for dropbox and added warning (#1593) 2024-06-08 17:26:07 -07:00
hagen-danswer
cbc53fd500 moved methods to top and fixed logic errors 2024-06-07 20:34:24 -07:00
Nils
7a3c102c74 Added support for Sites.Selected permissions to the SharePoint Connector and enabled the selection of individual subfolders (#1583) 2024-06-07 15:47:17 -07:00
Weves
4274c114c5 Fix display for std markdown without a language 2024-06-07 15:22:06 -07:00
Weves
d56e6c495a Fix code block copy 2024-06-07 15:22:06 -07:00
Weves
c9160c705a Fix list of assistants in Assistants Modal 2024-06-07 14:21:57 -07:00
Chris Weaver
3bc46ef40e Fix slack bot with document set (#1588)
Also includes a Persona db layer refactor
2024-06-07 14:14:08 -07:00
hagen-danswer
ff59858327 final bugfixes 2024-06-07 13:19:48 -07:00
Yuhong Sun
643176407c Fix Dedupe (#1587) 2024-06-07 11:35:27 -07:00
Weves
eacfd8f33f Use errorHandlingFetcher 2024-06-07 11:25:48 -07:00
Weves
f6fb963419 Fix indexing status page crashing 2024-06-07 11:25:48 -07:00
hagen-danswer
16e023a8ce Revert "ran prettier"
This reverts commit 750c1df0bb.
2024-06-07 11:20:58 -07:00
hagen-danswer
b79820a309 Revert "Update init-letsencrypt.sh (#1380)"
This reverts commit 9e0b6aa531.
2024-06-07 11:05:57 -07:00
hagen-danswer
754b735174 Revert "fix gitlab-connector - wrong datetime format (#1559)"
This reverts commit 8dfba97c09.
2024-06-07 11:00:01 -07:00
hagen-danswer
58c305a539 Revert "Add Dropbox connector (#956)"
This reverts commit 914dc27a8f.
2024-06-07 10:58:37 -07:00
hagen-danswer
26bc785625 Revert "Update README.md with fixed Slack link round 2"
This reverts commit 0b6e85c26b.
2024-06-07 10:53:15 -07:00
Yuhong Sun
09da456bba Remove Redundant Dedupe Logic (#1577) 2024-06-06 14:36:41 -07:00
Yuhong Sun
da43bac456 Dedupe Flag (#1576) 2024-06-06 14:10:40 -07:00
Weves
adcbd354f4 Fix fast model update + slight reword on model selection 2024-06-05 18:43:37 -07:00
Weves
41fbaf5698 Fix prompt access 2024-06-05 18:43:13 -07:00
Hagen O'Neill
0b83396c4d disabled dropbox polling 2024-06-05 15:14:08 -07:00
Hagen O'Neill
785d7736ed extracted semantic identifier into its own method 2024-06-05 14:37:53 -07:00
Hagen O'Neill
9a9a879aee bugfixes 2024-06-05 14:30:34 -07:00
Hagen O'Neill
7b36f7aa4f chat_message: ChatMessage 2024-06-05 14:18:28 -07:00
Hagen O'Neill
8d74176348 completed code revision suggestions 2024-06-05 14:11:38 -07:00
Hagen O'Neill
713d325f42 fixed rebase issues 2024-06-04 21:25:16 -07:00
Hagen O'Neill
f34b26b3d0 seperated teams and sharepoint enviornment variables 2024-06-04 20:08:24 -07:00
Hagenoneill
a2349af65c added teams connector 2024-06-04 20:08:07 -07:00
Bill Yang
914dc27a8f Add Dropbox connector (#956)
* start dropbox connector

* add wip ui

* polish ui

* Fix some ci

* ignore types

* addressed, fixed, and tested all comments

* ran prettier

* ran mypy fixes

---------

Co-authored-by: Bill Yang <bill@Bills-MacBook-Pro.local>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-06-04 20:04:50 -07:00
Chris Weaver
0b6e85c26b Update README.md with fixed Slack link round 2 2024-06-04 20:04:50 -07:00
Chris Weaver
291a3f9ca0 Fix Slack link in README.md 2024-06-04 20:04:50 -07:00
NP
8dfba97c09 fix gitlab-connector - wrong datetime format (#1559) 2024-06-04 20:04:50 -07:00
Keiron Stoddart
9e0b6aa531 Update init-letsencrypt.sh (#1380)
Add ability to accept a `.ev.nginx` configuration that lists a subdomain.
2024-06-04 20:04:50 -07:00
Bill Yang
1fb47d70b3 Add Dropbox connector (#956)
* start dropbox connector

* add wip ui

* polish ui

* Fix some ci

* ignore types

* addressed, fixed, and tested all comments

* ran prettier

* ran mypy fixes

---------

Co-authored-by: Bill Yang <bill@Bills-MacBook-Pro.local>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-06-04 17:58:01 -07:00
Chris Weaver
25d40f8daa Update README.md with fixed Slack link round 2 2024-06-04 17:19:59 -07:00
Chris Weaver
154cdec0db Fix Slack link in README.md 2024-06-04 16:42:33 -07:00
Hagen O'Neill
f5deb37fde fixed mypy issues 2024-06-04 13:35:30 -07:00
Hagen O'Neill
750c1df0bb ran prettier 2024-06-04 13:23:02 -07:00
Hagen O'Neill
e608febb7f fixed messages nesting 2024-06-04 13:17:19 -07:00
hagen-danswer
c8e10282b2 Update connector.py
removed unused var
2024-06-04 15:49:05 -04:00
hagen-danswer
d7f66ba8c4 Update connector.py
changed raise output
2024-06-04 15:46:55 -04:00
NP
4d5a39628f fix gitlab-connector - wrong datetime format (#1559) 2024-06-04 10:41:19 -07:00
Keiron Stoddart
b6d0ecec4f Update init-letsencrypt.sh (#1380)
Add ability to accept a `.ev.nginx` configuration that lists a subdomain.
2024-06-04 10:38:01 -07:00
Hagen O'Neill
14a39e88e8 seperated teams and sharepoint enviornment variables 2024-06-04 09:45:10 -07:00
hagen-danswer
ea71b9830c Update types.ts
removed semi
2024-06-04 12:11:19 -04:00
hagen-danswer
a9834853ef Merge branch 'main' into add-teams-connector 2024-06-04 12:09:52 -04:00
hagen-danswer
8b10535c93 Merge pull request #1505 from mboret/fix/sharepoint-connector-missing-drive-items
fix sharepoint connector missing objects
2024-06-04 10:27:22 -04:00
pablodanswer
e1e1f036a7 Remove React Markdown for human messages (#1562)
* Remove React Markdown for human messages

* Update globals.css

* add formatting

* formatting

---------

Co-authored-by: “Pablo <“pablo@danswer.ai”>
2024-06-03 18:20:46 -07:00
Weves
0c642f25dd Add image generation for gpt-4o 2024-06-03 16:38:49 -07:00
Weves
3788041115 Fix missing end of answer for quotes processor 2024-06-03 16:37:27 -07:00
Weves
5a75470d23 Fix no-search assistants with DISABLE_LLM_CHOOSE_SEARCH enabled 2024-06-03 14:38:42 -07:00
Weves
81aada7c0f Add option to hide logout 2024-06-03 12:13:40 -07:00
Chris Weaver
e4a08c5546 Improve error msg in chat UI (#1557) 2024-06-02 19:53:04 -07:00
Weves
61d096533c Allow multiple files to be selected for file upload 2024-06-02 16:20:59 -07:00
Weves
0543abac9a Move to matrix builds 2024-06-02 15:57:53 -07:00
Weves
1d0ce49c05 Fix slowness due to hitting async Postgres driver pool limit 2024-06-01 20:18:14 -07:00
Weves
6e9d7acb9c Latency logging 2024-06-01 20:18:14 -07:00
Yuhong Sun
026652d827 Helm tuning (#1553) 2024-06-01 17:29:58 -07:00
Yuhong Sun
2363698c20 Consolidate Helm Charts (#1552) 2024-06-01 16:46:44 -07:00
Clay Rosenthal
0ea257d030 adding secrets to helm (#1541)
* adding secrets to helm

* incrementing chart version
2024-05-31 19:11:36 -07:00
Yuhong Sun
d141e637d0 Disable Search if User uploads files in Chat (#1550) 2024-05-31 19:07:56 -07:00
Yuhong Sun
4b53cb56a6 Fix File Type Migration (#1549) 2024-05-31 18:35:36 -07:00
Weves
b690ae05b4 Add assistant gallery 2024-05-29 21:05:31 -07:00
Matthieu Boret
fbdf882299 fix sharepoint connector missing objects 2024-05-29 10:13:41 +02:00
Yuhong Sun
44d57f1b53 Reenable force search (#1531) 2024-05-28 11:36:02 -07:00
Art Matsak
d2b58bdb40 Fix DISABLE_LLM_CHOOSE_SEARCH being ignored (#1523) 2024-05-28 08:39:02 -07:00
Yuhong Sun
aa98200bec SlackBot Disable AI option (#1527) 2024-05-28 01:35:57 -07:00
Yuhong Sun
32c37f8b17 Update Prompt (#1526) 2024-05-28 01:08:29 -07:00
Weves
008a91bff0 Partial fix for links not working 2024-05-27 17:31:10 -07:00
Weves
9a3613eb44 Fix unknown languages causing the chat to crash 2024-05-27 17:31:10 -07:00
Yuhong Sun
90d5b41901 Fix Citation Prompt Optionality (#1524) 2024-05-27 14:11:58 -07:00
Yuhong Sun
8688226003 Remove Tag Unique Constraint Bug (#1519) 2024-05-26 15:10:56 -07:00
Weves
97d058b8b2 Fix mypy for mediawiki tests 2024-05-25 17:16:47 -07:00
Weves
26ef5b897d Code highlighting 2024-05-25 17:12:12 -07:00
Weves
dfd233b985 Fix mypy for mediawiki connector 2024-05-25 13:04:23 -07:00
Weves
2dab9c576c Add Slack payload log 2024-05-24 19:24:17 -07:00
Weves
a9f5952510 Fix migration multiple head issue 2024-05-24 14:43:52 -07:00
Andrew Sansom
94018e83b0 Add MediaWiki and Wikipedia Connectors (#1250)
* Add MediaWikiConnector first draft

* Add MediaWikiConnector first draft

* Add MediaWikiConnector first draft

* Add MediaWikiConnector sections for each document

* Add MediaWikiConnector to constants and factory

* Integrate MediaWikiConnector with connectors page

* Unit tests + bug fixes

* Allow adding multiple mediawikiconnectors

* add wikipedia connector

* add wikipedia connector to factory

* improve docstrings of mediawiki connector backend

* improve docstrings of mediawiki connector backend

* move wikipedia and mediawiki icon locations in admin page

* undo accidental commit of modified docker compose yaml
2024-05-24 08:51:20 -07:00
Hagenoneill
818dfd0413 JUST get_all LOL 2024-04-01 14:39:37 -04:00
Hagenoneill
51b4e63218 organized documents by post instead of by channel 2024-04-01 14:31:26 -04:00
Hagenoneill
73b063b66c added teams connector 2024-02-29 12:36:05 -05:00
728 changed files with 50197 additions and 10636 deletions


@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Backend Image Docker Build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: false
tags: |
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=v0.0.1


@@ -5,9 +5,14 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build-and-push:
runs-on: ubuntu-latest
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
@@ -30,8 +35,8 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-backend:${{ github.ref_name }}
danswer/danswer-backend:latest
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
@@ -39,6 +44,6 @@ jobs:
uses: aquasecurity/trivy-action@master
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore


@@ -5,40 +5,115 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build-and-push:
runs-on: ubuntu-latest
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: true
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
merge:
runs-on: ubuntu-latest
needs:
- build
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Web Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-web-server:${{ github.ref_name }}
danswer/danswer-web-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# needed due to weird interactions with the builds for different platforms
no-cache: true
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
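
The merge job above stitches the per-platform images, which the build job pushed by digest only, into a single multi-arch tag. In the spirit of the "To run locally" comments used elsewhere in these workflows, the step reduces to roughly the following sketch; the digest values are placeholders, not outputs of any real build:

# Rough local equivalent of the "Create manifest list and push" step above.
# Replace the <...> digests with the ones emitted by the per-platform builds.
docker buildx imagetools create \
  -t danswer/danswer-web-server:latest \
  danswer/danswer-web-server@sha256:<amd64-digest> \
  danswer/danswer-web-server@sha256:<arm64-digest>
# Confirm the resulting tag now lists both platforms:
docker buildx imagetools inspect danswer/danswer-web-server:latest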


@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: false
build-args: |
DANSWER_VERSION=v0.0.1
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

.gitignore (2 changes)

@@ -5,3 +5,5 @@
.idea
/deployment/data/nginx/app.conf
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml

.vscode/env_template.txt (new file, 52 lines)

@@ -0,0 +1,52 @@
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
# This will help with development iteration speed and reduce repeat tasks for dev
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
# Python stuff
PYTHONPATH=./backend
PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False


@@ -17,6 +17,7 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -28,6 +29,7 @@
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -45,8 +47,9 @@
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_ALL_MODEL_INTERACTIONS": "True",
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
@@ -63,6 +66,7 @@
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MINI_CHUNK": "false",
"LOG_LEVEL": "DEBUG",
@@ -77,7 +81,9 @@
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -100,6 +106,24 @@
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Pytest",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
]
}
}


@@ -72,6 +72,10 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
directory
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate


@@ -1,6 +1,10 @@
MIT License
Copyright (c) 2023-present DanswerAI, Inc.
Copyright (c) 2023 Yuhong Sun, Chris Weaver
Portions of this software are licensed as follows:
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal


@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -105,5 +105,25 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...
## 📚 Editions
There are two editions of Danswer:
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources
* Usage analytics and query history accessible to admins
* Whitelabeling
* API key authentication
* Encryption of secrets
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
To try the Danswer Enterprise Edition:
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

backend/.gitignore (2 changes)

@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*


@@ -1,15 +1,17 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
more details, visit https://github.com/danswer-ai/danswer."
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
contains code for both the Community and Enterprise editions of Danswer. If you do not \
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
Edition features outside of personal development or testing purposes. Please reach out to \
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -17,18 +19,32 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# zip for Vespa step futher down
# ca-certificates for HTTPS
RUN apt-get update && \
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 && \
apt-get install -y \
cmake \
curl \
zip \
ca-certificates \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && playwright install-deps chromium && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
@@ -36,11 +52,20 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
libldap-2.5-0 libldap-2.5-0 && \
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
@@ -53,12 +78,24 @@ nltk.download('punkt', quiet=True);"
# Set up application files
WORKDIR /app
# Enterprise Version Files
COPY ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
# Default command which does nothing

View File

@@ -0,0 +1,31 @@
"""Add thread specific model selection
Revision ID: 0568ccf46a6b
Revises: e209dc5a8156
Create Date: 2024-06-19 14:25:36.376046
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_session", "current_alternate_model")
# ### end Alembic commands ###

View File

@@ -0,0 +1,32 @@
"""add search doc relevance details
Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_doc",
sa.Column("is_relevant", sa.Boolean(), nullable=True),
)
op.add_column(
"search_doc",
sa.Column("relevance_explanation", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("search_doc", "relevance_explanation")
op.drop_column("search_doc", "is_relevant")

View File

@@ -0,0 +1,86 @@
"""remove-feedback-foreignkey-constraint
Revision ID: 23957775e5f5
Revises: bc9771dccadf
Create Date: 2024-06-27 16:04:51.480437
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],
ondelete="SET NULL",
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=True,
)
def downgrade() -> None:
op.alter_column(
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
)
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
)
op.alter_column(
"document_retrieval_feedback",
"chat_message_id",
existing_type=sa.Integer(),
nullable=False,
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval",
"chat_message",
["chat_message_id"],
["id"],
)

View File

@@ -0,0 +1,38 @@
"""add alternate assistant to chat message
Revision ID: 3a7802814195
Revises: 23957775e5f5
Create Date: 2024-06-05 11:18:49.966333
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_persona",
"chat_message",
"persona",
["alternate_assistant_id"],
["id"],
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")

View File

@@ -0,0 +1,23 @@
"""added is_internet to DBDoc
Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"
def upgrade() -> None:
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("tool", "display_name")
op.drop_column("search_doc", "is_internet")

View File

@@ -0,0 +1,61 @@
"""Add support for custom tools
Revision ID: 48d14957fe80
Revises: b85f02ec1308
Create Date: 2024-06-09 14:58:19.946509
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "48d14957fe80"
down_revision = "b85f02ec1308"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"tool",
sa.Column(
"openapi_schema",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"tool",
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
)
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
op.create_table(
"tool_call",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("tool_id", sa.Integer(), nullable=False),
sa.Column("tool_name", sa.String(), nullable=False),
sa.Column(
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
),
)
def downgrade() -> None:
op.drop_table("tool_call")
op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
op.drop_column("tool", "user_id")
op.drop_column("tool", "openapi_schema")

View File

@@ -0,0 +1,35 @@
"""added slack_auto_filter
Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015
"""
from alembic import op
import sqlalchemy as sa
revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"slack_bot_config",
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
)
op.alter_column(
"slack_bot_config",
"enable_auto_filters",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
def downgrade() -> None:
op.drop_column("slack_bot_config", "enable_auto_filters")

View File

@@ -0,0 +1,27 @@
"""Add chosen_assistants to User table
Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "chosen_assistants")

View File

@@ -0,0 +1,28 @@
"""fix-file-type-migration
Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE file_store
SET file_origin = UPPER(file_origin)
"""
)
def downgrade() -> None:
# Let's not break anything on purpose :)
pass

View File

@@ -0,0 +1,23 @@
"""backfill is_internet data to False
Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
def downgrade() -> None:
pass

View File

@@ -0,0 +1,51 @@
"""create usage reports table
Revision ID: bc9771dccadf
Revises: 0568ccf46a6b
Create Date: 2024-06-18 10:04:26.800282
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bc9771dccadf"
down_revision = "0568ccf46a6b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"usage_reports",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("report_name", sa.String(), nullable=False),
sa.Column(
"requestor_user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["report_name"],
["file_store.file_name"],
),
sa.ForeignKeyConstraint(
["requestor_user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("usage_reports")

View File

@@ -0,0 +1,75 @@
"""Add standard_answer tables
Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"standard_answer",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("keyword", sa.String(), nullable=False),
sa.Column("answer", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("keyword"),
)
op.create_table(
"standard_answer_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"standard_answer__standard_answer_category",
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
)
op.create_table(
"slack_bot_config__standard_answer_category",
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["slack_bot_config_id"],
["slack_bot_config.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
)
op.add_column(
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("chat_session", "slack_thread_id")
op.drop_table("slack_bot_config__standard_answer_category")
op.drop_table("standard_answer__standard_answer_category")
op.drop_table("standard_answer_category")
op.drop_table("standard_answer")

View File

@@ -0,0 +1,45 @@
"""combined slack id fields
Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
channel_config,
'{respond_member_group_list}',
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
) - 'respond_team_member_list' - 'respond_slack_group_list'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
jsonb_set(
channel_config - 'respond_member_group_list',
'{respond_team_member_list}',
'[]'::jsonb
),
'{respond_slack_group_list}',
'[]'::jsonb
)
"""
)
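
As a rough illustration of the `jsonb_set` merge performed by the upgrade above, here is a minimal Python sketch of the equivalent per-row dict transformation; the example `channel_config` contents are made up:

```python
# Minimal sketch of the per-row merge performed by the upgrade SQL above.
# The example channel_config contents are made up.
def merge_respond_lists(channel_config: dict) -> dict:
    merged = channel_config.get("respond_team_member_list", []) + channel_config.get(
        "respond_slack_group_list", []
    )
    new_config = {
        key: value
        for key, value in channel_config.items()
        if key not in ("respond_team_member_list", "respond_slack_group_list")
    }
    new_config["respond_member_group_list"] = merged
    return new_config


print(
    merge_respond_lists(
        {
            "channel_names": ["#support"],
            "respond_team_member_list": ["alice@example.com"],
            "respond_slack_group_list": ["oncall-group"],
        }
    )
)
# {'channel_names': ['#support'], 'respond_member_group_list': ['alice@example.com', 'oncall-group']}
```

Note that the downgrade recreates both original keys as empty lists and drops the merged key, so the original split between team members and Slack groups cannot be recovered.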

View File

@@ -0,0 +1,22 @@
"""added-prune-frequency
Revision ID: e209dc5a8156
Revises: 48d14957fe80
Create Date: 2024-06-16 16:02:35.273231
"""
from alembic import op
import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
def downgrade() -> None:
op.drop_column("connector", "prune_freq")

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ec85f2b3c544"
down_revision = "3879338f8ba1"
down_revision = "70f00c45c0f2"
branch_labels: None = None
depends_on: None = None

backend/assets/.gitignore
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -1,6 +1,7 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
@@ -19,7 +20,7 @@ def _get_access_for_documents(
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, is_public)
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
}
@@ -38,12 +39,6 @@ def get_access_for_documents(
) # type: ignore
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@@ -1,20 +1,30 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
def to_acl(self) -> list[str]:
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)
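
To make the ACL change above concrete, here is a small usage sketch of the new `build` signature and the `to_acl` output, assuming the danswer backend package is importable; the UUID and group name are made up:

```python
# Sketch: ACL entries for a document readable by one user and one user group.
# Assumes the danswer backend package is importable; the UUID and group name are made up.
from uuid import UUID

from danswer.access.models import DocumentAccess

access = DocumentAccess.build(
    user_ids=[UUID("12345678-1234-5678-1234-567812345678")],
    user_groups=["engineering"],
    is_public=False,
)
print(access.to_acl())
# e.g. ['user_id:12345678-1234-5678-1234-567812345678', 'group:engineering']
```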

View File

@@ -0,0 +1,10 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"

View File

@@ -0,0 +1,21 @@
from typing import cast
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
USER_STORE_KEY = "INVITED_USERS"
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
return cast(list, store.load(USER_STORE_KEY))
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
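
A minimal usage sketch for these helpers, assuming a configured dynamic config store and using placeholder email addresses:

```python
# Sketch: persisting and reading back the invited-user list via the dynamic config store.
# The email addresses are placeholders; a reachable dynamic config store is assumed.
from danswer.auth.invited_users import get_invited_users, write_invited_users

num_written = write_invited_users(["alice@example.com", "bob@example.com"])
print(f"Stored {num_written} invited users")
print(get_invited_users())  # ['alice@example.com', 'bob@example.com']
```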

View File

@@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
)

View File

@@ -9,6 +9,12 @@ class UserRole(str, Enum):
ADMIN = "admin"
class UserStatus(str, Enum):
LIVE = "live"
INVITED = "invited"
DEACTIVATED = "deactivated"
class UserRead(schemas.BaseUser[uuid.UUID]):
role: UserRole

View File

@@ -1,4 +1,3 @@
import os
import smtplib
import uuid
from collections.abc import AsyncGenerator
@@ -27,6 +26,7 @@ from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import AUTH_TYPE
@@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
@@ -54,14 +55,13 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
_user_whitelist: list[str] | None = None
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
@@ -92,20 +92,8 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
if os.path.exists(USER_WHITELIST_FILE):
with open(USER_WHITELIST_FILE, "r") as file:
_user_whitelist = [line.strip() for line in file]
else:
_user_whitelist = []
return _user_whitelist
def verify_email_in_whitelist(email: str) -> None:
whitelist = get_user_whitelist()
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
@@ -163,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC

View File

@@ -4,13 +4,22 @@ from typing import cast
from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
@@ -22,8 +31,6 @@ from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
logger.exception(
f"Failed to run pruning for connector id {connector_id} due to {e}"
)
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
@@ -177,32 +252,48 @@ def sync_document_set_task(document_set_id: int) -> None:
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(f"Document set {document_set.id} syncing now!")
if should_sync_doc_set(document_set, db_session):
logger.info(f"Syncing the {document_set.name} document set")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
@@ -212,3 +303,11 @@ celery_app.conf.beat_schedule = {
"schedule": timedelta(seconds=5),
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
},
}
)
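
The heart of `prune_documents_task` above is a set difference between the IDs the connector source currently returns and the IDs already indexed; a stripped-down sketch with made-up document IDs:

```python
# Stripped-down sketch of the pruning decision in prune_documents_task.
# The document IDs are made up for illustration.
source_doc_ids = {"doc-1", "doc-2", "doc-4"}            # pulled from the connector source
indexed_doc_ids = {"doc-1", "doc-2", "doc-3", "doc-4"}  # currently stored locally

doc_ids_to_remove = list(indexed_doc_ids - source_doc_ids)
print(doc_ids_to_remove)  # ['doc-3'] -> deleted from the document index
```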

View File

@@ -0,0 +1,9 @@
"""Entry point for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
celery_app = fetch_versioned_implementation(
"danswer.background.celery.celery_app", "celery_app"
)

View File

@@ -1,8 +1,32 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.engine import get_db_current_time
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_deletion_status(
@@ -21,3 +45,94 @@ def get_deletion_status(
credential_id=credential_id,
status=task_state.status,
)
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now!")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If the PruneConnector hasn't been implemented for the given connector, just pull
all docs using load_from_state and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
doc_batch_generator = None
if isinstance(runnable_connector, IdConnector):
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
elif isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
if doc_batch_generator:
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
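
To illustrate how the optional throttle in `extract_ids_from_runnable_connector` is applied, here is a rough sketch of wrapping `document_batch_to_ids` with `rate_limit_builder`; the limit of 50 batches per minute is an arbitrary example, not a project default:

```python
# Sketch: throttling how quickly document batches are converted to ID sets during pruning.
# The limit of 50 calls per 60 seconds is an arbitrary example, not a project default.
from danswer.background.celery.celery_utils import document_batch_to_ids
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)

throttled_batch_to_ids = rate_limit_builder(max_calls=50, period=60)(
    document_batch_to_ids
)

# Used in place of document_batch_to_ids when iterating the connector's batches:
# for doc_batch in doc_batch_generator:
#     all_connector_doc_ids.update(throttled_batch_to_ids(doc_batch))
```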

View File

@@ -41,7 +41,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def _delete_connector_credential_pair_batch(
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
@@ -169,7 +169,7 @@ def delete_connector_credential_pair(
if not documents:
break
_delete_connector_credential_pair_batch(
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,

View File

@@ -105,7 +105,9 @@ class SimpleJobClient:
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug("No available workers to run job")
logger.debug(
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
)
return None
job_id = self.job_id_counter

View File

@@ -6,11 +6,7 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import (
_delete_connector_credential_pair_batch,
)
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -21,8 +17,6 @@ from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -37,6 +31,7 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -46,7 +41,7 @@ def _get_document_generator(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> tuple[GenerateDocumentsOutput, bool]:
) -> GenerateDocumentsOutput:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -57,16 +52,13 @@ def _get_document_generator(
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
runnable_connector = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
attempt.credential,
db_session,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
@@ -75,7 +67,7 @@ def _get_document_generator(
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
is_listing_complete = True
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
@@ -88,13 +80,12 @@ def _get_document_generator(
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
is_listing_complete = False
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, is_listing_complete
return doc_batch_generator
def _run_indexing(
@@ -107,7 +98,6 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@@ -125,6 +115,8 @@ def _run_indexing(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -166,7 +158,7 @@ def _run_indexing(
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
doc_batch_generator, is_listing_complete = _get_document_generator(
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
@@ -224,39 +216,6 @@ def _run_indexing(
docs_removed_from_index=0,
)
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
# clean up all documents from the index that have not been returned from the connector
all_indexed_document_ids = {
d.id
for d in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
}
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids
)
logger.debug(
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
)
# delete docs from cc-pair and receive the number of completely deleted docs in return
_delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=db_connector.id,
credential_id=db_credential.id,
document_index=document_index,
)
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=len(doc_ids_to_remove),
)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
@@ -329,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
@@ -346,11 +306,14 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(index_attempt_id: int) -> None:
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)

View File

@@ -22,6 +22,15 @@ def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -35,6 +35,8 @@ from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
@@ -307,10 +309,18 @@ def kickoff_indexing_jobs(
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint, attempt.id, pure=False
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
else:
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -333,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -398,6 +410,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
set_is_ee_based_on_env_variable()
logger.info("Starting Indexing Loop")
update_loop()

View File

@@ -1,5 +1,4 @@
import re
from collections.abc import Sequence
from typing import cast
from sqlalchemy.orm import Session
@@ -9,42 +8,30 @@ from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
document_id=inference_section.center_chunk.document_id,
# This one is using the combined content of all the chunks of the section
# In default settings, this is the same as just the content of base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
content=inference_section.combined_content,
blurb=inference_section.center_chunk.blurb,
semantic_identifier=inference_section.center_chunk.semantic_identifier,
source_type=inference_section.center_chunk.source_type,
metadata=inference_section.center_chunk.metadata,
updated_at=inference_section.center_chunk.updated_at,
link=inference_section.center_chunk.source_links[0]
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
)
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
def create_chat_chain(
chat_session_id: int,
db_session: Session,

View File

@@ -1,18 +1,16 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.chat import get_prompt_by_name
from danswer.db.chat import upsert_persona
from danswer.db.chat import upsert_prompt
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Prompt as PromptDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
@@ -50,7 +48,7 @@ def load_personas_from_yaml(
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
@@ -58,22 +56,24 @@ def load_personas_from_yaml(
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
prompt_set_names = persona["prompts"]
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
prompts = [
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if not prompts:
prompts = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
upsert_persona(
@@ -91,8 +91,8 @@ def load_personas_from_yaml(
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
default_persona=True,
is_public=True,
db_session=db_session,

View File

@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
# TODO make this document level. Also slight misnomer here as this is actually
# done at the section level currently rather than the chunk
relevant: bool | None = None
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
@@ -106,12 +116,18 @@ class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
response: dict
tool_name: str
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
)
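
For reference, here is a small sketch of constructing the new relevance-summary payload, assuming the danswer backend package is importable; the dictionary key and explanation text are invented:

```python
# Sketch: building the new relevance-summary streaming payload.
# The dictionary key and explanation text are invented for illustration.
from danswer.chat.models import LLMRelevanceSummaryResponse, RelevanceChunk

summary = LLMRelevanceSummaryResponse(
    relevance_summaries={
        "doc-1": RelevanceChunk(
            relevant=True,
            content="Mentions the exact error code the user asked about.",
        )
    }
)
print(summary.dict())
```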

View File

@@ -7,15 +7,18 @@ from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -30,7 +33,9 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -43,25 +48,44 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.factory import get_tool_cls
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -94,14 +118,21 @@ def _handle_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
deduped_docs = top_docs
if dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
for top_doc in top_docs
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
@@ -121,12 +152,48 @@ def _handle_search_tool_response_summary(
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
),
reference_db_search_docs,
dropped_inds,
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.HYBRID,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
if new_msg_req.file_descriptors:
return None
if (
new_msg_req.query_override
or (
@@ -134,6 +201,7 @@ def _check_should_force_search(
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
@@ -146,7 +214,7 @@ def _check_should_force_search(
args = {"query": new_msg_req.message}
return ForceUseTool(
tool_name=SearchTool.name(),
tool_name=SearchTool._NAME,
args=args,
)
return None
@@ -160,6 +228,7 @@ ChatPacket = (
| DanswerAnswerPiece
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
)
ChatPacketStream = Iterator[ChatPacket]
@@ -177,6 +246,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -199,7 +269,15 @@ def stream_chat_message_objects(
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
persona = chat_session.persona
alternate_assistant_id = new_msg_req.alternate_assistant_id
# use an alternate persona if an alternate assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id, user=user, db_session=db_session
)
else:
persona = chat_session.persona
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
@@ -211,8 +289,10 @@ def stream_chat_message_objects(
)
try:
llm = get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
llm, fast_llm = get_llms_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -267,8 +347,8 @@ def stream_chat_message_objects(
"Be sure to update the chat pointers before calling this."
)
# Save now to save the latest chat message
db_session.commit()
# NOTE: do not commit user message - it will be committed when the
# assistant message is successfully generated
else:
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
@@ -298,10 +378,11 @@ def stream_chat_message_objects(
new_file.to_file_descriptor() for new_file in latest_query_files
],
db_session=db_session,
commit=False,
)
selected_db_search_docs = None
selected_llm_docs: list[LlmDoc] | None = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
@@ -311,8 +392,8 @@ def stream_chat_message_objects(
)
# Generates full documents currently
# May extend to include chunk ranges
selected_llm_docs = inference_documents_from_ids(
# May extend to use sections instead in the future
selected_sections = inference_sections_from_ids(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
@@ -353,77 +434,121 @@ def stream_chat_message_objects(
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=True,
commit=False,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
persona_tool_classes = [
get_tool_cls(tool, db_session) for tool in persona.tools
]
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
persona_tool_classes
)
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
# NOTE: for now, only support SearchTool and ImageGenerationTool
# in the future, will support arbitrary user-defined tools
search_tool: SearchTool | None = None
tools: list[Tool] = []
for tool_cls in persona_tool_classes:
if tool_cls.__name__ == SearchTool.__name__:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm_config=llm.config,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tools.append(search_tool)
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
if llm and llm.config.api_key and llm.config.model_provider == "openai":
dalle_key = llm.config.api_key
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
tools.append(ImageGenerationTool(api_key=dalle_key))
# LLM prompt building, response capturing, etc.
answer = Answer(
question=final_msg.message,
@@ -437,33 +562,61 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
llm=(
llm
or get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
additional_headers=litellm_additional_headers,
)
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=_check_should_force_search(new_msg_req),
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
)
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet, db_session, selected_db_search_docs
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=packet.response
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
@@ -480,16 +633,41 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
except Exception as e:
logger.exception(e)
logger.exception("Failed to process chat message")
# Frontend will erase whatever answer and show this instead
# This will be the issue 99% of the time
yield StreamingError(error="LLM failed to respond, have you set your API key?")
# Don't leak the API key
error_msg = str(e)
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
error_msg = (
f"LLM failed to respond. Invalid API "
f"key error from '{llm.config.model_provider}'."
)
yield StreamingError(error=error_msg)
# Cancel the transaction so that no messages are saved
db_session.rollback()
return
# Post-LLM answer processing
@@ -502,6 +680,11 @@ def stream_chat_message_objects(
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
@@ -512,7 +695,18 @@ def stream_chat_message_objects(
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
@@ -531,6 +725,7 @@ def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@@ -538,6 +733,7 @@ def stream_chat_message(
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
)
for obj in objects:
yield get_json_line(obj.dict())
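
Each object yielded above is serialized via get_json_line, so the HTTP response is a newline-delimited JSON stream with one ChatPacket per line. A minimal client-side sketch of consuming it, assuming a hypothetical /chat/send-message endpoint and an illustrative request payload:

import json
import requests

# Endpoint path and payload fields are assumptions for illustration only.
resp = requests.post(
    "http://localhost:8080/chat/send-message",
    json={"chat_session_id": 1, "message": "What is Danswer?", "parent_message_id": None},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    packet = json.loads(line)  # one serialized ChatPacket per line
    print(packet)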

View File

@@ -8,6 +8,7 @@ prompts:
# System Prompt (as shown in UI)
system: >
You are a question answering system that is constantly learning and improving.
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
@@ -21,8 +22,9 @@ prompts:
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and existing knowledge.
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Prompts the LLM to include citations in the form [1], [2], etc.
@@ -32,7 +34,16 @@ prompts:
- name: "OnlyLLM"
description: "Chat directly with the LLM!"
system: "You are a helpful assistant."
system: >
You are a helpful AI assistant. The current date is DANSWER_DATETIME_REPLACEMENT.
You give concise responses to very simple questions, but provide more thorough responses to
more complex and open-ended questions.
You are happy to help with writing, analysis, question answering, math, coding and all sorts
of other tasks. You use markdown where reasonable and also for coding.
task: ""
datetime_aware: true
include_citations: true
@@ -43,10 +54,11 @@ prompts:
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
IMPORTANT: YOU ONLY SUMMARIZE THE IMPORTANT INFORMATION FROM THE PROVIDED DOCUMENTS,
NEVER USE YOUR OWN KNOWLEDGE.
task: >
@@ -61,7 +73,8 @@ prompts:
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
Quote and cite relevant information from provided context based on the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
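
As the comments above describe, the DANSWER_DATETIME_REPLACEMENT token is swapped for the current date/time when the prompt is built. A minimal sketch of that substitution (the helper name and exact formatting below are illustrative assumptions, not the project's actual implementation):

from datetime import datetime

def inject_current_datetime(prompt: str) -> str:
    # e.g. "October 16, 2023 14:30", matching the format described in the comments above
    now_str = datetime.now().strftime("%B %d, %Y %H:%M")
    return prompt.replace("DANSWER_DATETIME_REPLACEMENT", now_str)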

View File

@@ -4,6 +4,7 @@ import urllib.parse
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentIndexType
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
# App Configs
@@ -45,13 +46,14 @@ DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET")
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
# Turn off mask if admin users should see full credentials for data connectors.
MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@@ -160,6 +162,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_S
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS")
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
HtmlBasedConnectorTransformLinksStrategy.STRIP,
)
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
@@ -178,6 +185,12 @@ CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -188,6 +201,10 @@ GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None
GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -195,6 +212,22 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
# Comma-delimited list of Zendesk article labels to skip indexing for
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")
#####
# Indexing Configs
@@ -215,19 +248,17 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores an additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer-grained chunking for more detail retention
# Slightly larger since the sentence-aware split is a max cutoff, so most mini-chunks will be under MINI_CHUNK_SIZE
# tokens. It still needs to be at least 1/4 of the chunk size to avoid leaving a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
# If set to true, documents that "no longer exist" will not be cleaned up when running Load connectors
DISABLE_DOCUMENT_CLEANUP = (
os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true"
)
#####
@@ -242,15 +273,20 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
# Sets LiteLLM to verbose logging
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# Logs Danswer only model interactions like prompts, responses, messages etc.
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
@@ -263,3 +299,15 @@ TOKEN_BUDGET_GLOBALLY_ENABLED = (
CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
)
#####
# Enterprise Edition Configs
#####
# NOTE: this should only be enabled if you have purchased an enterprise license.
# if you're interested in an enterprise license, please reach out to us at
# founders@danswer.ai OR message Chris Weaver or Yuhong Sun in the Danswer
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
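
Most of the boolean flags in this file follow the same pattern: the feature is enabled only when the environment variable is literally "true" (case-insensitive); anything else, including "1" or an unset variable, leaves it off. A small sketch of that behavior, using a hypothetical variable name:

import os

os.environ["MY_FEATURE_FLAG"] = "True"   # hypothetical variable, for illustration only
enabled = os.environ.get("MY_FEATURE_FLAG", "").lower() == "true"
print(enabled)  # True

os.environ["MY_FEATURE_FLAG"] = "1"      # any value other than "true"/"True"/... keeps it off
enabled = os.environ.get("MY_FEATURE_FLAG", "").lower() == "true"
print(enabled)  # False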

View File

@@ -5,7 +5,10 @@ PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@@ -25,9 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# For the highest-matching base-size chunk, how many chunks above and below to pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If the user has changed the default model, it is most likely to a multilingual model;
# the stopwords are NLTK English stopwords, so in that case we would not want to drop them
@@ -64,9 +66,31 @@ TITLE_CONTENT_RATIO = max(
# A list of languages passed to the LLM to rephrase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
LANGUAGE_HINT = "\n" + (
os.environ.get("LANGUAGE_HINT")
or "IMPORTANT: Respond in the same language as my query!"
)
LANGUAGE_CHAT_NAMING_HINT = (
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None

View File

@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -41,17 +42,16 @@ DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
TOKEN_BUDGET = "token_budget"
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
ENABLE_TOKEN_BUDGET = "enable_token_budget"
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
INDEX_SEPARATOR = "==="
# For File Connector Metadata override file
DANSWER_METADATA_FILENAME = ".danswer_metadata.json"
# Messages
DISABLED_GEN_AI_MSG = (
@@ -93,9 +93,30 @@ class DocumentSource(str, Enum):
GOOGLE_SITES = "google_sites"
ZENDESK = "zendesk"
LOOPIO = "loopio"
DROPBOX = "dropbox"
SHAREPOINT = "sharepoint"
TEAMS = "teams"
SALESFORCE = "salesforce"
DISCOURSE = "discourse"
AXERO = "axero"
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"
class BlobType(str, Enum):
R2 = "r2"
S3 = "s3"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
# Special case, for internet search
NOT_APPLICABLE = "not_applicable"
class DocumentIndexType(str, Enum):
@@ -111,6 +132,11 @@ class AuthType(str, Enum):
SAML = "saml"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
class SearchFeedbackType(str, Enum):
ENDORSE = "endorse" # boost this document for all future queries
REJECT = "reject" # down-boost this document for all future queries
@@ -136,3 +162,5 @@ class FileOrigin(str, Enum):
CHAT_UPLOAD = "chat_upload"
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"

View File

@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Auto detect query options like time cutoff or heavily favor recently updated docs
DISABLE_DANSWER_BOT_FILTER_DETECT = (
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
@@ -73,3 +69,7 @@ DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME") or 1
DANSWER_BOT_FEEDBACK_REMINDER = int(
os.environ.get("DANSWER_BOT_FEEDBACK_REMINDER") or 0
)
# Set to True to rephrase the Slack users' messages
DANSWER_BOT_REPHRASE_MESSAGE = (
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
)

View File

@@ -39,8 +39,8 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, the only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
@@ -100,7 +100,7 @@ DISABLE_LITELLM_STREAMING = (
).lower() == "true"
# extra headers to pass to LiteLLM
LITELLM_EXTRA_HEADERS = None
LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
_LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
if _LITELLM_EXTRA_HEADERS_RAW:
try:
@@ -113,3 +113,18 @@ if _LITELLM_EXTRA_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
)
# if specified, will pass through request headers to the call to the LLM
LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
if _LITELLM_PASS_THROUGH_HEADERS_RAW:
try:
LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
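
Both settings above are parsed with json.loads: LITELLM_EXTRA_HEADERS expects a JSON object of header name/value pairs, while LITELLM_PASS_THROUGH_HEADERS expects a JSON list of header names. A quick sketch with illustrative values (the header names are deployment-specific assumptions):

import json
import os

os.environ["LITELLM_EXTRA_HEADERS"] = json.dumps({"X-Org-Id": "1234"})
os.environ["LITELLM_PASS_THROUGH_HEADERS"] = json.dumps(["Authorization", "X-Request-Id"])

extra_headers = json.loads(os.environ["LITELLM_EXTRA_HEADERS"])        # dict[str, str]
pass_through = json.loads(os.environ["LITELLM_PASS_THROUGH_HEADERS"])  # list[str]
print(extra_headers, pass_through)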

View File

@@ -0,0 +1,277 @@
import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from typing import Any
from typing import Optional
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
class BlobStorageConnector(LoadConnector, PollConnector):
def __init__(
self,
bucket_type: str,
bucket_name: str,
prefix: str = "",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.bucket_type: BlobType = BlobType(bucket_type)
self.bucket_name = bucket_name
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
self.batch_size = batch_size
self.s3_client: Optional[S3Client] = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Checks for boto3 credentials based on the bucket type.
(1) R2: Access Key ID, Secret Access Key, Account ID
(2) S3: AWS Access Key ID, AWS Secret Access Key
(3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
(4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key
For each bucket type, the method initializes the appropriate S3 client:
- R2: Uses Cloudflare R2 endpoint with S3v4 signature
- S3: Creates a standard boto3 S3 client
- GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
- OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint
Raises ConnectorMissingCredentialError if required credentials are missing.
Raises ValueError for unsupported bucket types.
"""
logger.info(
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
)
if self.bucket_type == BlobType.R2:
if not all(
credentials.get(key)
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
):
raise ConnectorMissingCredentialError("Cloudflare R2")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
aws_access_key_id=credentials["r2_access_key_id"],
aws_secret_access_key=credentials["r2_secret_access_key"],
region_name="auto",
config=Config(signature_version="s3v4"),
)
elif self.bucket_type == BlobType.S3:
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
self.s3_client = session.client("s3")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
self.s3_client = boto3.client(
"s3",
endpoint_url="https://storage.googleapis.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name="auto",
)
elif self.bucket_type == BlobType.OCI_STORAGE:
if not all(
credentials.get(key)
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name=credentials["region"],
)
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
return None
def _download_object(self, key: str) -> bytes:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
return object["Body"].read()
# NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
# def _get_presigned_url(self, key: str) -> str:
# if self.s3_client is None:
# raise ConnectorMissingCredentialError("Blob storage")
# url = self.s3_client.generate_presigned_url(
# "get_object",
# Params={"Bucket": self.bucket_name, "Key": key},
# ExpiresIn=self.presign_length,
# )
# return url
def _get_blob_link(self, key: str) -> str:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
if self.bucket_type == BlobType.R2:
account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.S3:
region = self.s3_client.meta.region_name
return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.OCI_STORAGE:
namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
region = self.s3_client.meta.region_name
return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
def _yield_blob_objects(
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
for page in pages:
if "Contents" not in page:
continue
for obj in page["Contents"]:
if obj["Key"].endswith("/"):
continue
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
if not start <= last_modified <= end:
continue
downloaded_file = self._download_object(obj["Key"])
link = self._get_blob_link(obj["Key"])
name = os.path.basename(obj["Key"])
try:
text = extract_file_text(
name,
BytesIO(downloaded_file),
break_on_unprocessable=False,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception as e:
logger.exception(
f"Error decoding object {obj['Key']} as UTF-8: {e}"
)
if batch:
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
logger.info("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
for batch in self._yield_blob_objects(start_datetime, end_datetime):
yield batch
return None
if __name__ == "__main__":
credentials_dict = {
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
# Initialize the connector
connector = BlobStorageConnector(
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
bucket_name=os.environ.get("BUCKET_NAME") or "test",
prefix="",
)
try:
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print("First batch of documents:")
for doc in document_batch:
print(f"Document ID: {doc.id}")
print(f"Semantic Identifier: {doc.semantic_identifier}")
print(f"Source: {doc.source}")
print(f"Updated At: {doc.doc_updated_at}")
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
print(f" - Text: {section.text[:100]}...")
print("---")
break
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")

View File

@@ -0,0 +1,216 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"
class ClickupConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
api_token: str | None = None,
team_id: str | None = None,
connector_type: str | None = None,
connector_ids: list[str] | None = None,
retrieve_task_comments: bool = True,
) -> None:
self.batch_size = batch_size
self.api_token = api_token
self.team_id = team_id
self.connector_type = connector_type if connector_type else "workspace"
self.connector_ids = connector_ids
self.retrieve_task_comments = retrieve_task_comments
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.api_token = credentials["clickup_api_token"]
self.team_id = credentials["clickup_team_id"]
return None
@retry_builder()
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
if not self.api_token:
raise ConnectorMissingCredentialError("Clickup")
headers = {"Authorization": self.api_token}
response = requests.get(
f"{CLICKUP_API_BASE_URL}/{endpoint}", headers=headers, params=params
)
response.raise_for_status()
return response.json()
def _get_task_comments(self, task_id: str) -> list[Section]:
url_endpoint = f"/task/{task_id}/comment"
response = self._make_request(url_endpoint)
comments = [
Section(
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
text=comment_dict["comment_text"],
)
for comment_dict in response["comments"]
]
return comments
def _get_all_tasks_filtered(
self,
start: int | None = None,
end: int | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
page: int = 0
params = {
"include_markdown_description": "true",
"include_closed": "true",
"page": page,
}
if start is not None:
params["date_updated_gt"] = start
if end is not None:
params["date_updated_lt"] = end
if self.connector_type == "list":
params["list_ids[]"] = self.connector_ids
elif self.connector_type == "folder":
params["project_ids[]"] = self.connector_ids
elif self.connector_type == "space":
params["space_ids[]"] = self.connector_ids
url_endpoint = f"/team/{self.team_id}/task"
while True:
response = self._make_request(url_endpoint, params)
page += 1
params["page"] = page
for task in response["tasks"]:
document = Document(
id=task["id"],
source=DocumentSource.CLICKUP,
semantic_identifier=task["name"],
doc_updated_at=(
datetime.fromtimestamp(
round(float(task["date_updated"]) / 1000, 3)
).replace(tzinfo=timezone.utc)
),
primary_owners=[
BasicExpertInfo(
display_name=task["creator"]["username"],
email=task["creator"]["email"],
)
],
secondary_owners=[
BasicExpertInfo(
display_name=assignee["username"],
email=assignee["email"],
)
for assignee in task["assignees"]
],
title=task["name"],
sections=[
Section(
link=task["url"],
text=(
task["markdown_description"]
if "markdown_description" in task
else task["description"]
),
)
],
metadata={
"id": task["id"],
"status": task["status"]["status"],
"list": task["list"]["name"],
"project": task["project"]["name"],
"folder": task["folder"]["name"],
"space_id": task["space"]["id"],
"tags": [tag["name"] for tag in task["tags"]],
"priority": (
task["priority"]["priority"]
if "priority" in task and task["priority"] is not None
else ""
),
},
)
extra_fields = [
"date_created",
"date_updated",
"date_closed",
"date_done",
"due_date",
]
for extra_field in extra_fields:
if extra_field in task and task[extra_field] is not None:
document.metadata[extra_field] = task[extra_field]
if self.retrieve_task_comments:
document.sections.extend(self._get_task_comments(task["id"]))
doc_batch.append(document)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if response.get("last_page") is True or len(response["tasks"]) < 100:
break
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
if self.api_token is None:
raise ConnectorMissingCredentialError("Clickup")
return self._get_all_tasks_filtered(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.api_token is None:
raise ConnectorMissingCredentialError("Clickup")
return self._get_all_tasks_filtered(int(start * 1000), int(end * 1000))
if __name__ == "__main__":
import os
clickup_connector = ClickupConnector()
clickup_connector.load_credentials(
{
"clickup_api_token": os.environ["clickup_api_token"],
"clickup_team_id": os.environ["clickup_team_id"],
}
)
latest_docs = clickup_connector.load_from_state()
for doc in latest_docs:
print(doc)
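
poll_source takes start/end as seconds since the Unix epoch and converts them to the millisecond timestamps ClickUp expects. A sketch of polling only the last 24 hours (the list id is a placeholder):

import os
import time

connector = ClickupConnector(connector_type="list", connector_ids=["123456"])
connector.load_credentials(
    {
        "clickup_api_token": os.environ["clickup_api_token"],
        "clickup_team_id": os.environ["clickup_team_id"],
    }
)

end = time.time()              # seconds since the Unix epoch
start = end - 24 * 60 * 60     # 24 hours ago
for batch in connector.poll_source(start, end):
    print(f"Fetched {len(batch)} tasks")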

View File

@@ -1,3 +1,5 @@
import io
import os
from collections.abc import Callable
from collections.abc import Collection
from datetime import datetime
@@ -13,6 +15,7 @@ from requests import HTTPError
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -27,22 +30,25 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.html_utils import format_document_soup
from danswer.utils.logger import setup_logger
logger = setup_logger()
# Potential Improvements
# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space
# 2. Include attachments, etc
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
wiki_base is https://danswer.atlassian.net/wiki
space is 1234abcd
page_id is 5678efgh
"""
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -51,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
+ parsed_url.netloc
+ parsed_url.path.split("/spaces")[0]
)
space = parsed_url.path.split("/")[3]
return wiki_base, space
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
https://danswer.ai/confluence/display/1234abcd/overview
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
wiki_base is https://danswer.ai/confluence
space is 1234abcd
page_id is 5678efgh
"""
# /display/ is always right before the space and at the end of the base url
# /display/ is always right before the space and at the end of the base URL
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -72,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st
+ parsed_url.path.split(DISPLAY)[0]
)
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
return wiki_base, space
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
@@ -83,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
try:
if is_confluence_cloud:
wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
wiki_url
)
else:
wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
except Exception as e:
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
return wiki_base, space, is_confluence_cloud
return wiki_base, space, page_id, is_confluence_cloud
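
To make the return shape concrete, here is what the extraction yields for the sample cloud URL documented above:

wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
    "https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview"
)
# wiki_base == "https://danswer.atlassian.net/wiki"
# space == "1234abcd"
# page_id == "5678efgh"
# is_cloud is True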
@lru_cache()
@@ -147,6 +167,24 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filenames currently in use
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -174,10 +212,135 @@ def _comment_dfs(
return comments_str
class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: Confluence,
index_origin: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1  # NOTE: hard-coded to 1, overriding the passed-in batch_size
self.confluence_client = confluence_client
self.index_origin = index_origin
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def _fetch_origin_page(
self,
) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
)
return {}
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
return child_pages
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
return child_pages
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_origin: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -188,11 +351,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_origin = index_origin
(
self.wiki_base,
self.space,
self.page_id,
self.is_cloud,
) = extract_confluence_keys_from_url(wiki_page_url)
self.space_level_scan = False
self.confluence_client: Confluence | None = None
if self.page_id is None or self.page_id == "":
self.space_level_scan = True
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
@@ -210,8 +389,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self,
confluence_client: Confluence,
start_ind: int,
) -> Collection[dict[str, Any]]:
def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]:
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
@@ -220,9 +399,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind,
limit=batch_size,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
except Exception:
@@ -241,9 +422,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind + i,
limit=1,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
)
@@ -264,17 +447,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return view_pages
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_origin=self.index_origin,
)
return self.recursive_indexer.get_pages(start_ind, batch_size)
pages: list[dict[str, Any]] = []
try:
return _fetch(start_ind, self.batch_size)
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception as e:
if not self.continue_on_failure:
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
pages: list[dict[str, Any]] = []
for i in range(self.batch_size):
try:
pages.extend(_fetch(start_ind + i, 1))
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
@@ -286,6 +493,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)
try:
comment_pages = cast(
Collection[dict[str, Any]],
@@ -321,6 +529,50 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.exception("Ran into exception when fetching labels from Confluence")
return []
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
) -> str:
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500
)
for attachment in attachments_container["results"]:
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
continue
if attachment["title"] not in files_in_used:
continue
download_link = confluence_client.url + attachment["_links"]["download"]
response = confluence_client._session.get(download_link)
if response.status_code == 200:
extract = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
)
files_attachment_content.append(extract)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
f"Ran into exception when fetching attachments from Confluence: {e}"
)
return "\n".join(files_attachment_content)
def _get_doc_batch(
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
) -> tuple[list[Document], int]:
@@ -328,8 +580,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified_str = page["version"]["when"]
author = cast(str | None, page["version"].get("by", {}).get("email"))
@@ -345,15 +597,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter is None or time_filter(last_modified):
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
# check disallowed labels
if self.labels_to_skip:
page_labels = self._fetch_labels(self.confluence_client, page_id)
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -366,8 +621,19 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
attachment_text = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": self.space
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
doc_batch.append(
Document(
@@ -376,12 +642,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=[BasicExpertInfo(email=author)]
if author
else None,
metadata={
"Wiki Space Name": self.space,
},
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
)
)
return doc_batch, len(batch)
@@ -423,8 +687,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
connector.load_credentials(
{

View File

@@ -1,10 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar
from requests import HTTPError
from retry import retry
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
@retry(
exceptions=ConfluenceRateLimitError,
tries=10,
delay=1,
max_delay=600, # 10 minutes
backoff=2,
jitter=1,
)
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
raise ConfluenceRateLimitError()
raise
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(10):
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
time.sleep(retry_after)
else:
logger.warning(
"Rate limit hit. Retrying with exponential backoff..."
)
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
else:
# re-raise, let caller handle
raise
return cast(F, wrapped_call)
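For orientation, a minimal standalone sketch of the retry policy implemented by wrapped_call above: honor a server-supplied Retry-After value when one is present, otherwise fall back to capped exponential backoff. The RateLimitedError class and call_with_rate_limit_retries helper are illustrative stand-ins, not part of the connector code.

import time


class RateLimitedError(Exception):
    """Raised by a call when the server reports that we are rate limited."""

    def __init__(self, retry_after: int | None = None) -> None:
        super().__init__("rate limited")
        self.retry_after = retry_after


def call_with_rate_limit_retries(func, tries: int = 10, starting_delay: int = 5,
                                 backoff: int = 2, max_delay: int = 600):
    """Retry `func`, preferring the server's Retry-After hint over exponential backoff."""
    for attempt in range(tries):
        try:
            return func()
        except RateLimitedError as e:
            if e.retry_after is not None:
                # the server told us how long to wait
                time.sleep(e.retry_after)
            else:
                # capped exponential backoff: 5, 10, 20, ... up to max_delay seconds
                time.sleep(min(starting_delay * (backoff**attempt), max_delay))
    # still rate limited after exhausting all attempts
    raise RateLimitedError()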

View File

@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.text_processing import is_valid_email
@@ -57,3 +58,7 @@ def process_in_batches(
) -> Iterator[list[U]]:
for i in range(0, len(objects), batch_size):
yield [process_function(obj) for obj in objects[i : i + batch_size]]
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
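As a usage note, a hedged sketch of how a batching helper shaped like process_in_batches above can be driven. The signature (a sequence, a per-item function, and a batch size) is inferred from the loop body and is an assumption.

from collections.abc import Callable, Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")
U = TypeVar("U")


def process_in_batches(
    objects: Sequence[T], process_function: Callable[[T], U], batch_size: int
) -> Iterator[list[U]]:
    # mirrors the helper above: apply the function to fixed-size slices
    for i in range(0, len(objects), batch_size):
        yield [process_function(obj) for obj in objects[i : i + batch_size]]


if __name__ == "__main__":
    for batch in process_in_batches(list(range(7)), lambda x: x * x, batch_size=3):
        print(batch)  # [0, 1, 4] then [9, 16, 25] then [36]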

View File

@@ -11,6 +11,9 @@ from requests import Response
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
@@ -58,63 +61,36 @@ class DiscourseConnector(PollConnector):
self.category_id_map: dict[int, str] = {}
self.batch_size = batch_size
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")
return discourse_request(endpoint, self.permissions, params)
def _get_categories_map(
self,
) -> None:
assert self.permissions is not None
categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
response = discourse_request(
response = self._make_request(
endpoint=categories_endpoint,
perms=self.permissions,
params={"include_subcategories": True},
)
categories = response.json()["category_list"]["categories"]
self.category_id_map = {
category["id"]: category["name"]
for category in categories
if not self.categories or category["name"].lower() in self.categories
cat["id"]: cat["name"]
for cat in categories
if not self.categories or cat["name"].lower() in self.categories
}
def _get_latest_topics(
self, start: datetime | None, end: datetime | None
) -> list[int]:
assert self.permissions is not None
topic_ids = []
valid_categories = set(self.category_id_map.keys())
latest_endpoint = urllib.parse.urljoin(self.base_url, "latest.json")
response = discourse_request(endpoint=latest_endpoint, perms=self.permissions)
topics = response.json()["topic_list"]["topics"]
for topic in topics:
last_time = topic.get("last_posted_at")
if not last_time:
continue
last_time_dt = time_str_to_utc(last_time)
if start and start > last_time_dt:
continue
if end and end < last_time_dt:
continue
if valid_categories and topic.get("category_id") not in valid_categories:
continue
topic_ids.append(topic["id"])
return topic_ids
self.active_categories = set(self.category_id_map)
def _get_doc_from_topic(self, topic_id: int) -> Document:
assert self.permissions is not None
topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
response = discourse_request(
endpoint=topic_endpoint,
perms=self.permissions,
)
response = self._make_request(endpoint=topic_endpoint)
topic = response.json()
topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")
@@ -138,10 +114,16 @@ class DiscourseConnector(PollConnector):
sections.append(
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
metadata: dict[str, str | list[str]] = (
{
"category": category_name,
}
if category_name
else {}
)
metadata: dict[str, str | list[str]] = {
"category": self.category_id_map[topic["category_id"]],
}
if topic.get("tags"):
metadata["tags"] = topic["tags"]
@@ -157,26 +139,78 @@ class DiscourseConnector(PollConnector):
)
return doc
def _get_latest_topics(
self, start: datetime | None, end: datetime | None, page: int
) -> list[int]:
assert self.permissions is not None
topic_ids = []
if not self.categories:
latest_endpoint = urllib.parse.urljoin(
self.base_url, f"latest.json?page={page}"
)
response = self._make_request(endpoint=latest_endpoint)
topics = response.json()["topic_list"]["topics"]
else:
topics = []
empty_categories = []
for category_id in self.category_id_map.keys():
category_endpoint = urllib.parse.urljoin(
self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
)
response = self._make_request(endpoint=category_endpoint)
new_topics = response.json()["topic_list"]["topics"]
if len(new_topics) == 0:
empty_categories.append(category_id)
topics.extend(new_topics)
for empty_category in empty_categories:
self.category_id_map.pop(empty_category)
for topic in topics:
last_time = topic.get("last_posted_at")
if not last_time:
continue
last_time_dt = time_str_to_utc(last_time)
if (start and start > last_time_dt) or (end and end < last_time_dt):
continue
topic_ids.append(topic["id"])
if len(topic_ids) >= self.batch_size:
break
return topic_ids
def _yield_discourse_documents(
self, topic_ids: list[int]
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
page = 1
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if len(doc_batch) >= self.batch_size:
if doc_batch:
yield doc_batch
doc_batch = []
page += 1
if doc_batch:
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self.permissions = DiscoursePerms(
api_key=credentials["discourse_api_key"],
api_username=credentials["discourse_api_username"],
)
return None
def poll_source(
@@ -184,16 +218,13 @@ class DiscourseConnector(PollConnector):
) -> GenerateDocumentsOutput:
if self.permissions is None:
raise ConnectorMissingCredentialError("Discourse")
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
self._get_categories_map()
latest_topic_ids = self._get_latest_topics(
start=start_datetime, end=end_datetime
)
yield from self._yield_discourse_documents(latest_topic_ids)
yield from self._yield_discourse_documents(start_datetime, end_datetime)
if __name__ == "__main__":
@@ -209,7 +240,5 @@ if __name__ == "__main__":
current = time.time()
one_year_ago = current - 24 * 60 * 60 * 360
latest_docs = connector.poll_source(one_year_ago, current)
print(next(latest_docs))
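A stripped-down sketch of the paging loop introduced above: keep requesting page after page of topic ids until a page comes back empty, yielding documents in batches along the way. The fetch_topic_ids and build_doc callables are stand-ins for _get_latest_topics and _get_doc_from_topic, and the fake data is illustrative.

from collections.abc import Callable, Iterator


def paged_document_stream(
    fetch_topic_ids: Callable[[int], list[int]],  # stand-in for _get_latest_topics(start, end, page)
    build_doc: Callable[[int], dict],             # stand-in for _get_doc_from_topic
    batch_size: int,
) -> Iterator[list[dict]]:
    page = 1
    while topic_ids := fetch_topic_ids(page):
        doc_batch: list[dict] = []
        for topic_id in topic_ids:
            doc_batch.append(build_doc(topic_id))
            if len(doc_batch) >= batch_size:
                yield doc_batch
                doc_batch = []
        if doc_batch:
            yield doc_batch
        page += 1


if __name__ == "__main__":
    fake_pages = {1: [1, 2, 3], 2: [4, 5], 3: []}
    stream = paged_document_stream(
        lambda page: fake_pages.get(page, []),
        lambda topic_id: {"id": topic_id},
        batch_size=2,
    )
    for batch in stream:
        print(batch)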

View File

@@ -0,0 +1,155 @@
from datetime import timezone
from io import BytesIO
from typing import Any
from dropbox import Dropbox # type: ignore
from dropbox.exceptions import ApiError # type:ignore
from dropbox.files import FileMetadata # type:ignore
from dropbox.files import FolderMetadata # type:ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
class DropboxConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.dropbox_client: Dropbox | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.dropbox_client = Dropbox(credentials["dropbox_access_token"])
return None
def _download_file(self, path: str) -> bytes:
"""Download a single file from Dropbox."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
_, resp = self.dropbox_client.files_download(path)
return resp.content
def _get_shared_link(self, path: str) -> str:
"""Create a shared link for a file in Dropbox."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
try:
# Check if a shared link already exists
shared_links = self.dropbox_client.sharing_list_shared_links(path=path)
if shared_links.links:
return shared_links.links[0].url
link_metadata = (
self.dropbox_client.sharing_create_shared_link_with_settings(path)
)
return link_metadata.url
except ApiError as err:
logger.exception(f"Failed to create a shared link for {path}: {err}")
return ""
def _yield_files_recursive(
self,
path: str,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> GenerateDocumentsOutput:
"""Yield files in batches from a specified Dropbox folder, including subfolders."""
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
result = self.dropbox_client.files_list_folder(
path,
limit=self.batch_size,
recursive=False,
include_non_downloadable_files=False,
)
while True:
batch: list[Document] = []
for entry in result.entries:
if isinstance(entry, FileMetadata):
modified_time = entry.client_modified
if modified_time.tzinfo is None:
# If no timezone info, assume it is UTC
modified_time = modified_time.replace(tzinfo=timezone.utc)
else:
# Otherwise, convert it to UTC
modified_time = modified_time.astimezone(timezone.utc)
time_as_seconds = int(modified_time.timestamp())
if start and time_as_seconds < start:
continue
if end and time_as_seconds > end:
continue
downloaded_file = self._download_file(entry.path_display)
link = self._get_shared_link(entry.path_display)
try:
text = extract_file_text(
entry.name,
BytesIO(downloaded_file),
break_on_unprocessable=False,
)
batch.append(
Document(
id=f"doc:{entry.id}",
sections=[Section(link=link, text=text)],
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
doc_updated_at=modified_time,
metadata={"type": "article"},
)
)
except Exception as e:
logger.exception(
f"Error decoding file {entry.path_display} as utf-8 error occurred: {e}"
)
elif isinstance(entry, FolderMetadata):
yield from self._yield_files_recursive(entry.path_lower, start, end)
if batch:
yield batch
if not result.has_more:
break
result = self.dropbox_client.files_list_folder_continue(result.cursor)
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
for batch in self._yield_files_recursive("", start, end):
yield batch
return None
if __name__ == "__main__":
import os
connector = DropboxConnector()
connector.load_credentials(
{
"dropbox_access_token": os.environ["DROPBOX_ACCESS_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
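The time filtering in _yield_files_recursive depends on normalizing Dropbox's client_modified timestamps to UTC epoch seconds before comparing them with start/end. A small sketch of just that normalization, independent of the Dropbox SDK (the function names here are illustrative, not part of the connector):

from datetime import datetime, timezone


def to_utc_epoch_seconds(modified_time: datetime) -> int:
    """Normalize a possibly-naive datetime to UTC and return epoch seconds."""
    if modified_time.tzinfo is None:
        # naive timestamps are assumed to be UTC, as in the connector above
        modified_time = modified_time.replace(tzinfo=timezone.utc)
    else:
        modified_time = modified_time.astimezone(timezone.utc)
    return int(modified_time.timestamp())


def in_window(modified_time: datetime, start: int | None, end: int | None) -> bool:
    seconds = to_utc_epoch_seconds(modified_time)
    if start and seconds < start:
        return False
    if end and seconds > end:
        return False
    return True


if __name__ == "__main__":
    print(in_window(datetime(2024, 7, 1, 12, 0), start=1_700_000_000, end=None))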

View File

@@ -1,13 +1,18 @@
from typing import Any
from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
from danswer.connectors.bookstack.connector import BookstackConnector
from danswer.connectors.clickup.connector import ClickupConnector
from danswer.connectors.confluence.connector import ConfluenceConnector
from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
@@ -23,17 +28,23 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.linear.connector import LinearConnector
from danswer.connectors.loopio.connector import LoopioConnector
from danswer.connectors.mediawiki.wiki import MediaWikiConnector
from danswer.connectors.models import InputType
from danswer.connectors.notion.connector import NotionConnector
from danswer.connectors.productboard.connector import ProductboardConnector
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
from danswer.connectors.salesforce.connector import SalesforceConnector
from danswer.connectors.sharepoint.connector import SharepointConnector
from danswer.connectors.slab.connector import SlabConnector
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
from danswer.db.models import Credential
class ConnectorMissingException(Exception):
@@ -71,9 +82,19 @@ def identify_connector_class(
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
DocumentSource.ZENDESK: ZendeskConnector,
DocumentSource.LOOPIO: LoopioConnector,
DocumentSource.DROPBOX: DropboxConnector,
DocumentSource.SHAREPOINT: SharepointConnector,
DocumentSource.TEAMS: TeamsConnector,
DocumentSource.SALESFORCE: SalesforceConnector,
DocumentSource.DISCOURSE: DiscourseConnector,
DocumentSource.AXERO: AxeroConnector,
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -99,7 +120,6 @@ def identify_connector_class(
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
return connector
@@ -107,10 +127,14 @@ def instantiate_connector(
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credentials: dict[str, Any],
) -> tuple[BaseConnector, dict[str, Any] | None]:
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credentials)
new_credentials = connector.load_credentials(credential.credential_json)
return connector, new_credentials
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector
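A hedged sketch of the revised instantiation flow: build the connector from its config, hand it the stored credential JSON, and persist any refreshed credentials. FakeCredential and persist_credentials are illustrative stand-ins for the Credential model and backend_update_credential_json; this is not the production code path.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeCredential:  # stand-in for the Credential ORM model
    credential_json: dict[str, Any] = field(default_factory=dict)


def persist_credentials(credential: FakeCredential, new_json: dict[str, Any]) -> None:
    # stand-in for backend_update_credential_json(credential, new_creds, db_session)
    credential.credential_json = new_json


def instantiate(connector_class, config: dict[str, Any], credential: FakeCredential):
    connector = connector_class(**config)
    new_credentials = connector.load_credentials(credential.credential_json)
    if new_credentials is not None:
        # e.g. a refreshed OAuth token returned by load_credentials
        persist_credentials(credential, new_credentials)
    return connector


class DummyConnector:
    def __init__(self, space: str) -> None:
        self.space = space

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        return None  # nothing to refresh


if __name__ == "__main__":
    connector = instantiate(DummyConnector, {"space": "ENG"}, FakeCredential({"token": "x"}))
    print(type(connector).__name__)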

View File

@@ -69,7 +69,9 @@ def _process_file(
if is_text_file_extension(file_name):
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_text_file(file, encoding=encoding)
file_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_danswer_metadata=False
)
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf":
@@ -83,8 +85,18 @@ def _process_file(
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name_override = all_metadata.get("file_display_name")
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
)
title = (
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
)
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
if isinstance(time_updated, str):
@@ -99,6 +111,7 @@ def _process_file(
for k, v in all_metadata.items()
if k
not in [
"document_id",
"time_updated",
"doc_updated_at",
"link",
@@ -106,6 +119,7 @@ def _process_file(
"secondary_owners",
"filename",
"file_display_name",
"title",
]
}
@@ -124,13 +138,13 @@ def _process_file(
return [
Document(
id=f"FILE_CONNECTOR__{file_name}", # add a prefix to avoid conflicts with other connectors
id=doc_id,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
semantic_identifier=file_display_name_override
or os.path.basename(file_name),
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,

View File

@@ -1,4 +1,6 @@
import fnmatch
import itertools
from collections import deque
from collections.abc import Iterable
from collections.abc import Iterator
from datetime import datetime
@@ -6,7 +8,10 @@ from datetime import timezone
from typing import Any
import gitlab
import pytz
from gitlab.v4.objects import Project
from danswer.configs.app_configs import GITLAB_CONNECTOR_INCLUDE_CODE_FILES
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -19,7 +24,13 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
# List of directories/files to exclude
exclude_patterns = [
"logs",
".github/",
".gitlab/",
".pre-commit-config.yaml",
]
logger = setup_logger()
@@ -72,6 +83,37 @@ def _convert_issue_to_document(issue: Any) -> Document:
return doc
def _convert_code_to_document(
project: Project, file: Any, url: str, projectName: str, projectOwner: str
) -> Document:
file_content_obj = project.files.get(
file_path=file["path"], ref="master"
) # Replace 'master' with your branch name if needed
try:
file_content = file_content_obj.decode().decode("utf-8")
except UnicodeDecodeError:
file_content = file_content_obj.decode().decode("latin-1")
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
doc = Document(
id=file["id"],
sections=[Section(link=file_url, text=file_content)],
source=DocumentSource.GITLAB,
semantic_identifier=file["name"],
doc_updated_at=datetime.now().replace(
tzinfo=timezone.utc
), # Use current time as updated_at
primary_owners=[], # Fill this as needed
metadata={"type": "CodeFile"},
)
return doc
def _should_exclude(path: str) -> bool:
"""Check if a path matches any of the exclude patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns)
class GitlabConnector(LoadConnector, PollConnector):
def __init__(
self,
@@ -81,6 +123,7 @@ class GitlabConnector(LoadConnector, PollConnector):
state_filter: str = "all",
include_mrs: bool = True,
include_issues: bool = True,
include_code_files: bool = GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
) -> None:
self.project_owner = project_owner
self.project_name = project_name
@@ -88,6 +131,7 @@ class GitlabConnector(LoadConnector, PollConnector):
self.state_filter = state_filter
self.include_mrs = include_mrs
self.include_issues = include_issues
self.include_code_files = include_code_files
self.gitlab_client: gitlab.Gitlab | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -101,45 +145,80 @@ class GitlabConnector(LoadConnector, PollConnector):
) -> GenerateDocumentsOutput:
if self.gitlab_client is None:
raise ConnectorMissingCredentialError("Gitlab")
project = self.gitlab_client.projects.get(
project: gitlab.Project = self.gitlab_client.projects.get(
f"{self.project_owner}/{self.project_name}"
)
# Fetch code files
if self.include_code_files:
# Fetch using BFS since project.repository_tree with recursion causes slow loads
queue = deque([""]) # Start with the root directory
while queue:
current_path = queue.popleft()
files = project.repository_tree(path=current_path, all=True)
for file_batch in _batch_gitlab_objects(files, self.batch_size):
code_doc_batch: list[Document] = []
for file in file_batch:
if _should_exclude(file["path"]):
continue
if file["type"] == "blob":
code_doc_batch.append(
_convert_code_to_document(
project,
file,
self.gitlab_client.url,
self.project_name,
self.project_owner,
)
)
elif file["type"] == "tree":
queue.append(file["path"])
if code_doc_batch:
yield code_doc_batch
if self.include_mrs:
merge_requests = project.mergerequests.list(
state=self.state_filter, order_by="updated_at", sort="desc"
)
for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size):
doc_batch: list[Document] = []
mr_doc_batch: list[Document] = []
for mr in mr_batch:
mr.updated_at = datetime.strptime(
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%fZ"
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
)
if start is not None and mr.updated_at < start:
yield doc_batch
if start is not None and mr.updated_at < start.replace(
tzinfo=pytz.UTC
):
yield mr_doc_batch
return
if end is not None and mr.updated_at > end:
if end is not None and mr.updated_at > end.replace(tzinfo=pytz.UTC):
continue
doc_batch.append(_convert_merge_request_to_document(mr))
yield doc_batch
mr_doc_batch.append(_convert_merge_request_to_document(mr))
yield mr_doc_batch
if self.include_issues:
issues = project.issues.list(state=self.state_filter)
for issue_batch in _batch_gitlab_objects(issues, self.batch_size):
doc_batch = []
issue_doc_batch: list[Document] = []
for issue in issue_batch:
issue.updated_at = datetime.strptime(
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%fZ"
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
)
if start is not None and issue.updated_at < start:
yield doc_batch
return
if end is not None and issue.updated_at > end:
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
if start is not None:
start = start.replace(tzinfo=pytz.UTC)
if issue.updated_at < start:
yield issue_doc_batch
return
if end is not None:
end = end.replace(tzinfo=pytz.UTC)
if issue.updated_at > end:
continue
issue_doc_batch.append(_convert_issue_to_document(issue))
yield issue_doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_gitlab()
@@ -163,11 +242,12 @@ if __name__ == "__main__":
state_filter="all",
include_mrs=True,
include_issues=True,
include_code_files=GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
)
connector.load_credentials(
{
"github_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
"gitlab_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
"gitlab_url": os.environ["GITLAB_URL"],
}
)
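A compact, self-contained sketch of the breadth-first repository walk used for code files above: directories are queued, blobs collected, and excluded paths skipped. list_tree stands in for project.repository_tree, and the patterns and fake repo shown are examples only.

import fnmatch
from collections import deque
from collections.abc import Callable

EXCLUDE_PATTERNS = ["logs", ".github/*", ".gitlab/*", ".pre-commit-config.yaml"]


def should_exclude(path: str) -> bool:
    return any(fnmatch.fnmatch(path, pattern) for pattern in EXCLUDE_PATTERNS)


def bfs_blobs(list_tree: Callable[[str], list[dict]]) -> list[dict]:
    """Collect all file ('blob') entries, walking directories breadth-first."""
    blobs: list[dict] = []
    queue = deque([""])  # start at the repository root
    while queue:
        current_path = queue.popleft()
        for entry in list_tree(current_path):
            if should_exclude(entry["path"]):
                continue
            if entry["type"] == "blob":
                blobs.append(entry)
            elif entry["type"] == "tree":
                queue.append(entry["path"])
    return blobs


if __name__ == "__main__":
    fake_repo = {
        "": [{"path": "src", "type": "tree"}, {"path": "README.md", "type": "blob"}],
        "src": [{"path": "src/main.py", "type": "blob"}],
    }
    print([b["path"] for b in bfs_blobs(lambda path: fake_repo.get(path, []))])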

View File

@@ -3,7 +3,8 @@ from typing import Any
from typing import cast
from typing import Dict
from google.auth.credentials import Credentials # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
@@ -36,7 +37,7 @@ logger = setup_logger()
class GmailConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.creds: Credentials | None = None
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
@@ -45,7 +46,7 @@ class GmailConnector(LoadConnector, PollConnector):
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds = None
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
@@ -74,7 +75,7 @@ class GmailConnector(LoadConnector, PollConnector):
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(

View File

@@ -8,7 +8,8 @@ from itertools import chain
from typing import Any
from typing import cast
from google.auth.credentials import Credentials # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -41,6 +42,7 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -56,6 +58,10 @@ class GDriveMimeType(str, Enum):
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
GoogleDriveFileType = dict[str, Any]
@@ -324,6 +330,12 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
elif mime_type == GDriveMimeType.PDF.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pdf_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.POWERPOINT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PPT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
@@ -346,7 +358,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.creds: Credentials | None = None
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
@staticmethod
def _process_folder_paths(
@@ -387,7 +399,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds = None
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
@@ -416,7 +428,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
@@ -461,6 +473,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
if self.only_org_public:
if "permissions" not in file:
continue

View File

@@ -50,6 +50,12 @@ class PollConnector(BaseConnector):
raise NotImplementedError
class IdConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_source_ids(self) -> set[str]:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
import builtins
import functools
import itertools
from typing import Any
from unittest import mock
from urllib.parse import urlparse
from urllib.parse import urlunparse
from pywikibot import family # type: ignore[import-untyped]
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]
from danswer.utils.logger import setup_logger
logger = setup_logger()
@mock.patch.object(
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))
)
class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
"""A subclass of FamilyFileGenerator that writes the family file to memory instead of to disk."""
def __init__(
self,
url: str,
name: str,
dointerwiki: str | bool = True,
verify: str | bool = True,
):
"""Initialize the FamilyFileGeneratorInMemory."""
url_parse = urlparse(url, "https")
if not url_parse.netloc and url_parse.path:
url = urlunparse(
(url_parse.scheme, url_parse.path, url_parse.netloc, *url_parse[3:])
)
else:
url = urlunparse(url_parse)
assert isinstance(url, str)
if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
raise ValueError(
f'ERROR: Name of family "{name}" must be ASCII letters and digits [a-zA-Z0-9]'
)
if isinstance(dointerwiki, bool):
dointerwiki = "Y" if dointerwiki else "N"
assert isinstance(dointerwiki, str)
if isinstance(verify, bool):
verify = "Y" if verify else "N"
assert isinstance(verify, str)
super().__init__(url, name, dointerwiki, verify)
self.family_definition: type[family.Family] | None = None
def get_params(self) -> bool:
"""Get the parameters for the family class definition.
This override prevents the method from prompting the user for input (which would be impossible in this context).
We do all the input validation in the constructor.
"""
return True
def writefile(self, verify: Any) -> None:
"""Write the family file.
This overrides the method in the parent class to write the family definition to memory instead of to disk.
Args:
verify: unused argument necessary to match the signature of the method in the parent class.
"""
code_hostname_pairs = {
f"{k}": f"{urlparse(w.server).netloc}" for k, w in self.wikis.items()
}
code_path_pairs = {f"{k}": f"{w.scriptpath}" for k, w in self.wikis.items()}
code_protocol_pairs = {
f"{k}": f"{urlparse(w.server).scheme}" for k, w in self.wikis.items()
}
class Family(family.Family): # noqa: D101
"""The family definition for the wiki."""
name = "%(name)s"
langs = code_hostname_pairs
def scriptpath(self, code: str) -> str:
return code_path_pairs[code]
def protocol(self, code: str) -> str:
return code_protocol_pairs[code]
self.family_definition = Family
@functools.lru_cache(maxsize=None)
def generate_family_class(url: str, name: str) -> type[family.Family]:
"""Generate a family file for a given URL and name.
Args:
url: The URL of the wiki.
name: The short name of the wiki (customizable by the user).
Returns:
The family definition.
Raises:
ValueError: If the family definition was not generated.
"""
generator = FamilyFileGeneratorInMemory(url, name, "Y", "Y")
generator.run()
if generator.family_definition is None:
raise ValueError("Family definition was not generated.")
return generator.family_definition
def family_class_dispatch(url: str, name: str) -> type[family.Family]:
"""Find or generate a family class for a given URL and name.
Args:
url: The URL of the wiki.
name: The short name of the wiki (customizable by the user).
"""
if "wikipedia" in url:
import pywikibot.families.wikipedia_family # type: ignore[import-untyped]
return pywikibot.families.wikipedia_family.Family
# TODO: Support additional families pre-defined in `pywikibot.families.*_family.py` files
return generate_family_class(url, name)
if __name__ == "__main__":
url = "fallout.fandom.com/wiki/Fallout_Wiki"
name = "falloutfandom"
categories: list[str] = []
pages = ["Fallout: New Vegas"]
recursion_depth = 1
family_type = generate_family_class(url, name)
site = pywikibot.Site(fam=family_type(), code="en")
categories = [
pywikibot.Category(site, f"Category:{category.replace(' ', '_')}")
for category in categories
]
pages = [pywikibot.Page(site, page) for page in pages]
all_pages = itertools.chain(
pages,
*[
pagegenerators.CategorizedPageGenerator(category, recurse=recursion_depth)
for category in categories
],
)
for page in all_pages:
print(page.title())
print(page.text[:1000])

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
import datetime
import itertools
from collections.abc import Generator
from typing import Any
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot import textlib # type: ignore[import-untyped]
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.mediawiki.family import family_class_dispatch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
def pywikibot_timestamp_to_utc_datetime(
timestamp: pywikibot.time.Timestamp,
) -> datetime.datetime:
"""Convert a pywikibot timestamp to a datetime object in UTC.
Args:
timestamp: The pywikibot timestamp to convert.
Returns:
A datetime object in UTC.
"""
return datetime.datetime.astimezone(timestamp, tz=datetime.timezone.utc)
def get_doc_from_page(
page: pywikibot.Page, site: pywikibot.Site | None, source_type: DocumentSource
) -> Document:
"""Generate Danswer Document from a MediaWiki page object.
Args:
page: Page from a MediaWiki site.
site: MediaWiki site (used to parse the sections of the page using the site template, if available).
source_type: Source of the document.
Returns:
Generated document.
"""
page_text = page.text
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
sections = [
Section(
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
text=section.title + section.content,
)
for section in sections_extracted.sections
]
sections.append(
Section(
link=page.full_url(),
text=sections_extracted.header,
)
)
return Document(
source=source_type,
title=page.title(),
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
page.latest_revision.timestamp
),
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=page.pageid,
)
class MediaWikiConnector(LoadConnector, PollConnector):
"""A connector for MediaWiki wikis.
Args:
hostname: The hostname of the wiki.
categories: The categories to include in the index.
pages: The pages to include in the index.
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
connector_name: The name of the connector.
language_code: The language code of the wiki.
batch_size: The batch size for loading documents.
Raises:
ValueError: If `recurse_depth` is not an integer greater than or equal to -1.
"""
document_source_type: ClassVar[DocumentSource] = DocumentSource.MEDIAWIKI
"""DocumentSource type for all documents generated by instances of this class. Can be overridden for connectors
tailored for specific sites."""
def __init__(
self,
hostname: str,
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
if recurse_depth < -1:
raise ValueError(
f"recurse_depth must be an integer greater than or equal to -1. Got {recurse_depth} instead."
)
# -1 means infinite recursion, which `pywikibot` will only do with `True`
self.recurse_depth: bool | int = True if recurse_depth == -1 else recurse_depth
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.connector_name = connector_name
connector_name = "".join(ch for ch in connector_name if ch.isalnum())
self.family = family_class_dispatch(hostname, connector_name)()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
for category in categories
]
self.pages = [pywikibot.Page(self.site, page) for page in pages]
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load credentials for a MediaWiki site.
Note:
For most read-only operations, MediaWiki API credentials are not necessary.
This method can be overridden in the event that a particular MediaWiki site
requires credentials.
"""
return None
def _get_doc_batch(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Generator[list[Document], None, None]:
"""Request batches of pages from a MediaWiki site.
Args:
start: The beginning of the time period of pages to request.
end: The end of the time period of pages to request.
Yields:
Lists of Documents containing each parsed page in a batch.
"""
doc_batch: list[Document] = []
# Pywikibot can handle batching for us, including only loading page contents when we finally request them.
category_pages = [
pagegenerators.PreloadingGenerator(
pagegenerators.EdittimeFilterPageGenerator(
pagegenerators.CategorizedPageGenerator(
category, recurse=self.recurse_depth
),
last_edit_start=datetime.datetime.fromtimestamp(start)
if start
else None,
last_edit_end=datetime.datetime.fromtimestamp(end) if end else None,
),
groupsize=self.batch_size,
)
for category in self.categories
]
# Since we can specify both individual pages and categories, we need to iterate over all of them.
all_pages = itertools.chain(self.pages, *category_pages)
for page in all_pages:
doc_batch.append(
get_doc_from_page(page, self.site, self.document_source_type)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
"""Load all documents from the source.
Returns:
A generator of documents.
"""
return self.poll_source(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
"""Poll the source for new documents.
Args:
start: The start of the time range to poll.
end: The end of the time range to poll.
Returns:
A generator of documents.
"""
return self._get_doc_batch(start, end)
if __name__ == "__main__":
HOSTNAME = "fallout.fandom.com"
test_connector = MediaWikiConnector(
connector_name="Fallout",
hostname=HOSTNAME,
categories=["Fallout:_New_Vegas_factions"],
pages=["Fallout: New Vegas"],
recurse_depth=1,
)
all_docs = list(test_connector.load_from_state())
print("All docs", all_docs)
current = datetime.datetime.now().timestamp()
thirty_days_ago = current - 30 * 24 * 60 * 60  # 30 days
latest_docs = list(test_connector.poll_source(thirty_days_ago, current))
print("Latest docs", latest_docs)

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.utils.text_processing import make_url_compatible
@@ -13,6 +14,7 @@ class InputType(str, Enum):
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
POLL = "poll" # e.g. calling an API to get all documents in the last hour
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
PRUNE = "prune"
class ConnectorMissingCredentialError(PermissionError):
@@ -116,7 +118,12 @@ class DocumentBase(BaseModel):
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
return self.semantic_identifier if self.title is None else self.title
replace_chars = set(RETURN_SEPARATOR)
title = self.semantic_identifier if self.title is None else self.title
for char in replace_chars:
title = title.replace(char, " ")
title = title.strip()
return title
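The change above strips separator characters out of titles before they are used for embedding. A tiny illustration of the same cleaning step; the RETURN_SEPARATOR value below is an assumed sentinel, the real one lives in danswer.configs.constants.

RETURN_SEPARATOR = "\n\r\n"  # assumed sentinel value for illustration


def clean_title(title: str) -> str:
    for char in set(RETURN_SEPARATOR):
        title = title.replace(char, " ")
    return title.strip()


if __name__ == "__main__":
    print(repr(clean_title("Quarterly\nReport\r2024 ")))  # 'Quarterly Report 2024'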
def get_metadata_str_attributes(self) -> list[str] | None:
if not self.metadata:

View File

@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time <= end or compare_time > start:
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
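The one-line fix above matters because the old condition `compare_time <= end or compare_time > start` is true for any timestamp whenever start < end, so nothing was actually filtered. A quick check of both predicates with hypothetical epoch values:

start, end = 1_700_000_000, 1_700_100_000
far_outside_window = 1_600_000_000  # well before `start`

old_condition = far_outside_window <= end or far_outside_window > start
new_condition = far_outside_window > start and far_outside_window <= end

print(old_condition)  # True  -> the page would wrongly be kept
print(new_condition)  # False -> the page is correctly filtered out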

View File

@@ -0,0 +1,274 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.salesforce.utils import extract_dict_text
from danswer.utils.logger import setup_logger
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000  # Salesforce's SOQL query limit is 20,000 characters; stay well under it
ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
requested_objects: list[str] = [],
) -> None:
self.batch_size = batch_size
self.sf_client: Salesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
)
return None
def _get_sf_type_object_json(self, type_name: str) -> Any:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
def _get_name_from_id(self, id: str) -> str:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
)
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
return name
except Exception:
logger.warning(f"Couldnt find name for object id: {id}")
return "Null User"
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
) -> Document:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
salesforce_id = object_dict["Id"]
danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
)
]
doc = Document(
id=danswer_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=extracted_primary_owners,
metadata={},
)
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = self._get_sf_type_object_json(sf_type)
if not object_description["queryable"]:
return False
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = self.sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
for child_relationship in object_description["childRelationships"]:
if self._is_valid_child_object(child_relationship):
children_objects.append(
{
"relationship_name": child_relationship["relationshipName"],
"object_type": child_relationship["childSObject"],
}
)
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
fields = [
field.get("name")
for field in object_description["fields"]
if field.get("type", "base64") != "base64"
]
return fields
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
This function takes in an object_type and generates query(s) designed to grab
information associated to objects of that type.
It does that by getting all the fields of the parent object type.
Then it gets all the child objects of that object type and all the fields of
those children as well.
"""
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
else:
query += query_addition
query += f"\n FROM {parent_sf_type}"
yield query
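A hedged, standalone sketch of the query-splitting idea in _generate_query_per_parent_type: child sub-selects are appended until the SOQL string would exceed the length cap, at which point the current query is yielded and a fresh one is started. The object, field, and relationship names below are invented for illustration, and the cap is shrunk so the example actually splits.

from collections.abc import Iterator

MAX_QUERY_LENGTH = 80  # tiny cap so the example splits into two queries


def generate_queries(
    parent_type: str, parent_fields: list[str], children: list[tuple[str, list[str]]]
) -> Iterator[str]:
    query = f"SELECT {', '.join(parent_fields)}"
    for relationship_name, child_fields in children:
        addition = f", (SELECT {', '.join(child_fields)} FROM {relationship_name})"
        if len(query) + len(addition) > MAX_QUERY_LENGTH:
            # cap reached: emit the query built so far and start a new one
            yield query + f" FROM {parent_type}"
            query = "SELECT Id" + addition
        else:
            query += addition
    yield query + f" FROM {parent_type}"


if __name__ == "__main__":
    for q in generate_queries(
        "Account",
        ["Id", "Name", "Industry"],
        [("Contacts", ["Id", "Email"]), ("Opportunities", ["Id", "StageName", "Amount"])],
    ):
        print(q)
        print("---")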
def _fetch_from_salesforce(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_batch: list[Document] = []
for parent_object_type in self.parent_object_list:
logger.debug(f"Processing: {parent_object_type}")
query_results: dict = {}
for query in self._generate_query_per_parent_type(parent_object_type):
if start is not None and end is not None:
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
query_result = self.sf_client.query_all(query)
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
logger.info(
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
)
for combined_object_dict in query_results.values():
doc_batch.append(
self._convert_object_instance_to_document(combined_object_dict)
)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_source_ids(self) -> set[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
all_retrieved_ids: set[str] = set()
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
all_retrieved_ids.update(
f"{ID_PREFIX}{instance_dict.get('Id', '')}"
for instance_dict in query_result["records"]
)
return all_retrieved_ids
if __name__ == "__main__":
connector = SalesforceConnector(
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
)
connector.load_credentials(
{
"sf_username": os.environ["SF_USERNAME"],
"sf_password": os.environ["SF_PASSWORD"],
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -0,0 +1,66 @@
import re
from typing import Union
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
if "__c" in key: # remove the custom object indicator for display
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
if filtered_value: # Only add non-empty dictionaries or lists
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
if filtered_item: # Only add non-empty dictionaries or lists
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(item)  # append the scalar itself, not the previously filtered container
return filtered_list
else:
return data
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent))
else:
result.append(f"{indent_str}{data}")
return "\n".join(result)
def extract_dict_text(raw_dict: dict) -> str:
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_dict = _json_to_natural_language(processed_dict)
return natural_language_dict
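A small worked example of what extract_dict_text produces for a Salesforce-style record; the field names and values are invented. Keys matching the SF_JSON_FILTER pattern are dropped, the "attributes" entries are flattened in, the "__c" suffix is stripped for display, and the remainder is rendered as indented key/value lines.

from danswer.connectors.salesforce.utils import extract_dict_text

if __name__ == "__main__":
    record = {
        "attributes": {"type": "Account"},
        "Id": "001xx000003DGb2AAG",          # filtered out (ends in "Id")
        "Name": "Acme Corp",
        "LastModifiedDate": "2024-07-01",    # filtered out (ends in "Date")
        "Shipping_Notes__c": "Dock 4 only",  # "__c" suffix stripped for display
    }
    print(extract_dict_text(record))
    # Name: Acme Corp
    # Shipping_Notes: Dock 4 only
    # type: Account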

View File

@@ -1,8 +1,11 @@
import io
import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
@@ -19,44 +22,45 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import file_io_to_text
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import xlsx_to_text
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
UNSUPPORTED_FILE_TYPE_CONTENT = "" # idea copied from the google drive side of things
logger = setup_logger()
def get_text_from_xlsx_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
return xlsx_to_text(file=io.BytesIO(file_content))
@dataclass
class SiteData:
url: str | None
folder: Optional[str]
sites: list = field(default_factory=list)
driveitems: list = field(default_factory=list)
def get_text_from_docx_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
return docx_to_text(file=io.BytesIO(file_content))
def _convert_driveitem_to_document(
driveitem: DriveItem,
) -> Document:
file_text = extract_file_text(
file_name=driveitem.name,
file=io.BytesIO(driveitem.get_content().execute_query().value),
break_on_unprocessable=False,
)
def get_text_from_pdf_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
file_text = pdf_to_text(file=io.BytesIO(file_content))
return file_text
def get_text_from_txt_driveitem(driveitem_object: DriveItem) -> str:
file_content: bytes = driveitem_object.get_content().execute_query().value
return file_io_to_text(file=io.BytesIO(file_content))
def get_text_from_pptx_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
return pptx_to_text(file=io.BytesIO(file_content))
doc = Document(
id=driveitem.id,
sections=[Section(link=driveitem.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem.name,
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),
primary_owners=[
BasicExpertInfo(
display_name=driveitem.last_modified_by.user.displayName,
email=driveitem.last_modified_by.user.email,
)
],
metadata={},
)
return doc
class SharepointConnector(LoadConnector, PollConnector):
@@ -67,22 +71,112 @@ class SharepointConnector(LoadConnector, PollConnector):
) -> None:
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_site_list: list[str] = sites
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
@staticmethod
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
if "sites" in parts:
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
folder = (
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
)
site_data_list.append(
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
)
return site_data_list
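
To make the parsing above concrete, here is a standalone sketch that mirrors _extract_site_and_folder on a single hypothetical site URL (the tenant and names are invented):

url = "https://contoso.sharepoint.com/sites/Engineering/Shared Documents"
parts = url.strip().split("/")
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
folder = parts[sites_index + 2] if len(parts) > sites_index + 2 else None
# site_url == "https://contoso.sharepoint.com/sites/Engineering"
# folder == "Shared Documents", later used to filter driveitems by parent_reference.path
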
def _populate_sitedata_driveitems(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> None:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
for element in self.site_data:
sites: list[Site] = []
for site in element.sites:
site_sublist = site.lists.get().execute_query()
sites.extend(site_sublist)
for site in sites:
try:
query = site.drive.root.get_files(True, 1000)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
if element.folder:
filtered_driveitems = [
item
for item in driveitems
if element.folder in item.parent_reference.path
]
element.driveitems.extend(filtered_driveitems)
else:
element.driveitems.extend(driveitems)
except Exception:
                    # Sites include things that do not contain .drive.root, so this fails,
                    # but that is fine as there are no actual documents in those
pass
def _populate_sitedata_sites(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
if self.site_data:
for element in self.site_data:
element.sites = [
self.graph_client.sites.get_by_url(element.url)
.get()
.execute_query()
]
else:
sites = self.graph_client.sites.get().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
self._populate_sitedata_sites()
self._populate_sitedata_driveitems(start=start, end=end)
        # goes over all drive items, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for element in self.site_data:
for driveitem in element.driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
aad_client_id = credentials["aad_client_id"]
aad_client_secret = credentials["aad_client_secret"]
aad_directory_id = credentials["aad_directory_id"]
sp_client_id = credentials["sp_client_id"]
sp_client_secret = credentials["sp_client_secret"]
sp_directory_id = credentials["sp_directory_id"]
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
authority_url = f"https://login.microsoftonline.com/{aad_directory_id}"
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=aad_client_id,
client_credential=aad_client_secret,
client_id=sp_client_id,
client_credential=sp_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
@@ -92,122 +186,6 @@ class SharepointConnector(LoadConnector, PollConnector):
self.graph_client = GraphClient(_acquire_token_func)
return None
def get_all_driveitem_objects(
self,
site_object_list: list[Site],
start: datetime | None = None,
end: datetime | None = None,
) -> list[DriveItem]:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
driveitem_list = []
for site_object in site_object_list:
site_list_objects = site_object.lists.get().execute_query()
for site_list_object in site_list_objects:
try:
query = site_list_object.drive.root.get_files(True)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
driveitem_list.extend(driveitems)
except Exception:
                    # Sites include things that do not contain .drive.root, so this fails,
                    # but that is fine as there are no actual documents in those
pass
return driveitem_list
def get_all_site_objects(self) -> list[Site]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
site_object_list: list[Site] = []
sites_object = self.graph_client.sites.get().execute_query()
if len(self.requested_site_list) > 0:
for requested_site in self.requested_site_list:
adjusted_string = "/" + requested_site.replace(" ", "")
for site_object in sites_object:
if site_object.web_url.endswith(adjusted_string):
site_object_list.append(site_object)
else:
site_object_list.extend(sites_object)
return site_object_list
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
site_object_list = self.get_all_site_objects()
driveitem_list = self.get_all_driveitem_objects(
site_object_list=site_object_list,
start=start,
end=end,
)
# goes over all urls, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
batch_count = 0
for driveitem_object in driveitem_list:
logger.debug(f"Processing: {driveitem_object.web_url}")
doc_batch.append(
self.convert_driveitem_object_to_document(driveitem_object)
)
batch_count += 1
if batch_count >= self.batch_size:
yield doc_batch
batch_count = 0
doc_batch = []
yield doc_batch
def convert_driveitem_object_to_document(
self,
driveitem_object: DriveItem,
) -> Document:
file_text = self.extract_driveitem_text(driveitem_object)
doc = Document(
id=driveitem_object.id,
sections=[Section(link=driveitem_object.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem_object.name,
doc_updated_at=driveitem_object.last_modified_datetime.replace(
tzinfo=timezone.utc
),
primary_owners=[
BasicExpertInfo(
display_name=driveitem_object.last_modified_by.user.displayName,
email=driveitem_object.last_modified_by.user.email,
)
],
metadata={},
)
return doc
def extract_driveitem_text(self, driveitem_object: DriveItem) -> str:
driveitem_name = driveitem_object.name
driveitem_text = UNSUPPORTED_FILE_TYPE_CONTENT
if driveitem_name.endswith(".docx"):
driveitem_text = get_text_from_docx_driveitem(driveitem_object)
elif driveitem_name.endswith(".pdf"):
driveitem_text = get_text_from_pdf_driveitem(driveitem_object)
elif driveitem_name.endswith(".xlsx"):
driveitem_text = get_text_from_xlsx_driveitem(driveitem_object)
elif driveitem_name.endswith(".pptx"):
driveitem_text = get_text_from_pptx_driveitem(driveitem_object)
elif is_text_file_extension(driveitem_name):
driveitem_text = get_text_from_txt_driveitem(driveitem_object)
return driveitem_text
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_sharepoint()
@@ -224,9 +202,9 @@ if __name__ == "__main__":
connector.load_credentials(
{
"aad_client_id": os.environ["AAD_CLIENT_ID"],
"aad_client_secret": os.environ["AAD_CLIENT_SECRET"],
"aad_directory_id": os.environ["AAD_CLIENT_DIRECTORY_ID"],
"sp_client_id": os.environ["SP_CLIENT_ID"],
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()

View File

@@ -0,0 +1,278 @@
import os
from datetime import datetime
from datetime import timezone
from typing import Any
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_created_datetime(chat_message: ChatMessage) -> datetime:
# Extract the 'createdDateTime' value from the 'properties' dictionary and convert it to a datetime object
return time_str_to_utc(chat_message.properties["createdDateTime"])
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
members = channel.members.get().execute_query()
for member in members:
channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
return channel_members_list
def _get_threads_from_channel(
channel: Channel,
start: datetime | None = None,
end: datetime | None = None,
) -> list[list[ChatMessage]]:
# Ensure start and end are timezone-aware
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query = channel.messages.get()
base_messages: list[ChatMessage] = query.execute_query()
threads: list[list[ChatMessage]] = []
for base_message in base_messages:
message_datetime = time_str_to_utc(
base_message.properties["lastModifiedDateTime"]
)
if start and message_datetime < start:
continue
if end and message_datetime > end:
continue
reply_query = base_message.replies.get_all()
replies = reply_query.execute_query()
# start a list containing the base message and its replies
thread: list[ChatMessage] = [base_message]
thread.extend(replies)
threads.append(thread)
return threads
def _get_channels_from_teams(
teams: list[Team],
) -> list[Channel]:
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels_list.extend(channels)
return channels_list
def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
first_poster = (
top_message.properties.get("from", {})
.get("user", {})
.get("displayName", "Unknown User")
)
channel_name = channel.properties.get("displayName", "Unknown")
thread_subject = top_message.properties.get("subject", "Unknown")
snippet = parse_html_page_basic(top_message.body.content.rstrip())
snippet = snippet[:50] + "..." if len(snippet) > 50 else snippet
return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
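
A quick hypothetical illustration of the identifier format built above (all values invented; the snippet is the first ~50 characters of the parsed message body):

first_poster = "Jane Doe"             # from "from" -> "user" -> "displayName"
channel_name = "Release Planning"     # channel "displayName" property
thread_subject = "Q3 roadmap"         # top message "subject" property
snippet = "Here is the draft plan for the quarter..."

semantic_identifier = f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
# -> "Jane Doe in Release Planning about Q3 roadmap: Here is the draft plan for the quarter..."
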
def _convert_thread_to_document(
channel: Channel,
thread: list[ChatMessage],
) -> Document | None:
if len(thread) == 0:
return None
most_recent_message_datetime: datetime | None = None
top_message = thread[0]
post_members_list: list[BasicExpertInfo] = []
thread_text = ""
sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)
if sorted_thread:
most_recent_message = sorted_thread[0]
most_recent_message_datetime = time_str_to_utc(
most_recent_message.properties["createdDateTime"]
)
for message in thread:
        # append the parsed message text to the running thread text
if message.body.content:
message_text = parse_html_page_basic(message.body.content)
thread_text += message_text
        # if it has a subject, that means it's the top-level post message, so grab its id, url, and subject
if message.properties["subject"]:
top_message = message
# check to make sure there is a valid display name
if message.properties["from"]:
if message.properties["from"]["user"]:
if message.properties["from"]["user"]["displayName"]:
message_sender = message.properties["from"]["user"]["displayName"]
                    # if it's not a duplicate, add it to the list
if message_sender not in [
member.display_name for member in post_members_list
]:
post_members_list.append(
BasicExpertInfo(display_name=message_sender)
)
    # if no post members were found, fall back to the members of the parent channel
if not post_members_list:
post_members_list = _extract_channel_members(channel)
if not thread_text:
return None
semantic_string = _construct_semantic_identifier(channel, top_message)
post_id = top_message.properties["id"]
web_url = top_message.web_url
doc = Document(
id=post_id,
sections=[Section(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
title="", # teams threads don't really have a "title"
doc_updated_at=most_recent_message_datetime,
primary_owners=post_members_list,
metadata={},
)
return doc
class TeamsConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
teams: list[str] = [],
) -> None:
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_team_list: list[str] = teams
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
teams_client_id = credentials["teams_client_id"]
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
self.graph_client = GraphClient(_acquire_token_func)
return None
def _get_all_teams(self) -> list[Team]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
teams_list: list[Team] = []
teams = self.graph_client.teams.get().execute_query()
if len(self.requested_team_list) > 0:
adjusted_request_strings = [
requested_team.replace(" ", "")
for requested_team in self.requested_team_list
]
teams_list = [
team
for team in teams
if team.display_name.replace(" ", "") in adjusted_request_strings
]
else:
teams_list.extend(teams)
return teams_list
def _fetch_from_teams(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
teams = self._get_all_teams()
channels = _get_channels_from_teams(
teams=teams,
)
        # goes over each channel's threads, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread)
if converted_doc:
doc_batch.append(converted_doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_teams()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))
connector.load_credentials(
{
"teams_client_id": os.environ["TEAMS_CLIENT_ID"],
"teams_client_secret": os.environ["TEAMS_CLIENT_SECRET"],
"teams_directory_id": os.environ["TEAMS_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -29,6 +29,7 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.html_utils import web_html_cleanup
from danswer.utils.logger import setup_logger
from danswer.utils.sitemap import list_pages_for_site
logger = setup_logger()
@@ -145,11 +146,22 @@ def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
return [
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return urls
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
if not urlparse(maybe_relative_url).netloc:
@@ -214,6 +226,10 @@ class WebConnector(LoadConnector):
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = self.to_visit_list
if not to_visit:
raise ValueError("No URLs to visit")
base_url = to_visit[0] # For the recursive case
doc_batch: list[Document] = []
@@ -254,7 +270,7 @@ class WebConnector(LoadConnector):
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split(".")[-1],
semantic_identifier=current_url.split("/")[-1],
metadata={},
)
)
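
The semantic_identifier change above is easiest to see with a hypothetical page URL: splitting on "." only kept the file extension, while splitting on "/" keeps the page name. (The same file also gains the list_pages_for_site fallback when the configured URL does not look like a sitemap.)

current_url = "https://docs.example.com/guides/getting-started.html"
current_url.split(".")[-1]  # "html" -- old behavior, just the extension
current_url.split("/")[-1]  # "getting-started.html" -- new behavior, the page name
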

View File

@@ -0,0 +1,30 @@
from typing import ClassVar
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.mediawiki import wiki
class WikipediaConnector(wiki.MediaWikiConnector):
"""Connector for Wikipedia."""
document_source_type: ClassVar[DocumentSource] = DocumentSource.WIKIPEDIA
def __init__(
self,
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
super().__init__(
hostname="wikipedia.org",
categories=categories,
pages=pages,
recurse_depth=recurse_depth,
connector_name=connector_name,
language_code=language_code,
batch_size=batch_size,
)
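
A hypothetical usage sketch of the new connector, with invented category and page names, based on the constructor signature above:

connector = WikipediaConnector(
    categories=["Large language models"],
    pages=["Retrieval-augmented generation"],
    recurse_depth=1,
    connector_name="wikipedia-example",
    language_code="en",
)
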

View File

@@ -4,6 +4,7 @@ from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
@@ -81,7 +82,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
)
doc_batch = []
for article in articles:
if article.body is None or article.draft:
if (
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.label_names
)
):
continue
doc_batch.append(_article_to_document(article))
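
A standalone sketch of the new skip-label check, with invented label values; an article is skipped if any of its labels appears in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS:

ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = ["internal", "deprecated"]
article_label_names = ["how-to", "deprecated"]

should_skip = any(
    label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS for label in article_label_names
)
# should_skip is True here, so this article would not be indexed
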

View File

@@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.icons import source_to_github_img_link
@@ -353,6 +354,22 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_standard_answer_blocks(
answer_message: str,
) -> list[Block]:
generate_button_block = ButtonElement(
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
text="Generate Full Answer",
)
answer_block = SectionBlock(text=answer_message)
return [
answer_block,
ActionsBlock(
elements=[generate_button_block],
),
]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
@@ -457,7 +474,7 @@ def build_follow_up_resolved_blocks(
if tag_str:
tag_str += " "
group_str = " ".join([f"<!subteam^{group}>" for group in group_ids])
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
if group_str:
group_str += " "
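
For reference, the updated group-mention formatting above produces strings like the following (group IDs invented):

group_ids = ["S012AB3CD", "S045EF6GH"]
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
# -> "<!subteam^S012AB3CD|> <!subteam^S045EF6GH|>"
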

View File

@@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"
class FeedbackVisibility(str, Enum):

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from typing import cast
@@ -8,6 +9,7 @@ from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import make_slack_api_rate_limited
@@ -21,12 +23,17 @@ from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
)
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_groupids_from_names
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
@@ -36,7 +43,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger_base = setup_logger()
logger = setup_logger()
def handle_doc_feedback_button(
@@ -44,7 +51,7 @@ def handle_doc_feedback_button(
client: SocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger_base.error("Missing actions. Unable to build the source feedback view")
logger.error("Missing actions. Unable to build the source feedback view")
return
# Extracts the feedback_id coming from the 'source feedback' button
@@ -72,6 +79,66 @@ def handle_doc_feedback_button(
)
def handle_generate_answer_button(
req: SocketModeRequest,
client: SocketModeClient,
) -> None:
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
user_id = req.payload["user"]["id"]
if not thread_ts:
raise ValueError("Missing thread_ts in the payload")
thread_messages = read_slack_thread(
channel=channel_id, thread=thread_ts, client=client.web_client
)
# remove all assistant messages till we get to the last user message
# we want the new answer to be generated off of the last "question" in
# the thread
for i in range(len(thread_messages) - 1, -1, -1):
if thread_messages[i].role == MessageType.USER:
break
if thread_messages[i].role == MessageType.ASSISTANT:
thread_messages.pop(i)
    # Send an ephemeral message telling the user that we're generating the full answer
respond_in_thread(
client=client.web_client,
channel=channel_id,
receiver_ids=[user_id],
text="I'm working on generating a full answer for you. This may take a moment...",
thread_ts=thread_ts,
)
with Session(get_sqlalchemy_engine()) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
handle_regular_answer(
message_info=SlackMessageInfo(
thread_messages=thread_messages,
channel_to_respond=channel_id,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
bypass_filters=True,
is_bot_msg=False,
is_bot_dm=False,
),
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
channel=channel_id,
logger=cast(logging.Logger, logger),
feedback_reminder_id=None,
)
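
The trimming loop in handle_generate_answer_button walks the thread backwards, dropping trailing assistant messages until it reaches the last user question. A standalone sketch with an invented thread (plain tuples stand in for the real thread message objects):

USER, ASSISTANT = "user", "assistant"
thread_messages = [
    (USER, "How do I rotate my API key?"),
    (ASSISTANT, "Here is a short answer..."),
    (USER, "Can you give the full steps?"),
    (ASSISTANT, "Working on it..."),
]

for i in range(len(thread_messages) - 1, -1, -1):
    if thread_messages[i][0] == USER:
        break
    if thread_messages[i][0] == ASSISTANT:
        thread_messages.pop(i)

# thread_messages now ends at the last user question; the trailing assistant
# message has been removed, so the new answer is generated off that question.
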
def handle_slack_feedback(
feedback_id: str,
feedback_type: str,
@@ -129,7 +196,7 @@ def handle_slack_feedback(
feedback=feedback,
)
else:
logger_base.error(f"Feedback type '{feedback_type}' not supported")
logger.error(f"Feedback type '{feedback_type}' not supported")
if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [
LIKE_BLOCK_ACTION_ID,
@@ -193,11 +260,11 @@ def handle_followup_button(
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_userids_from_emails(
tag_ids, remaining = fetch_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:
group_ids, _ = fetch_groupids_from_names(remaining, client.web_client)
group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client)
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
@@ -272,7 +339,7 @@ def handle_followup_resolved_button(
)
if not response.get("ok"):
logger_base.error("Unable to delete message for resolved")
logger.error("Unable to delete message for resolved")
if immediate:
msg_text = f"{clicker_name} has marked this question as resolved!"

View File

@@ -1,89 +1,34 @@
import datetime
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
logger_base = setup_logger()
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender:
@@ -171,17 +116,9 @@ def remove_scheduled_feedback_reminder(
def handle_message(
message_info: SlackMessageInfo,
channel_config: SlackBotConfig | None,
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -198,14 +135,22 @@ def handle_message(
)
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = channel_config.persona if channel_config else None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
@@ -213,36 +158,13 @@ def handle_message(
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if channel_config is None
else channel_config.response_type == SlackBotResponseType.CITATIONS
)
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
respond_tag_only = False
respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
respond_member_group_list = None
channel_conf = None
if channel_config and channel_config.channel_config:
channel_conf = channel_config.channel_config
if slack_bot_config and slack_bot_config.channel_config:
channel_conf = slack_bot_config.channel_config
if not bypass_filters and "answer_filters" in channel_conf:
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
and "?" not in messages[-1].message
@@ -259,7 +181,7 @@ def handle_message(
)
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_team_member_list = channel_conf.get("respond_team_member_list") or None
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
if respond_tag_only and not bypass_filters:
logger.info(
@@ -268,12 +190,23 @@ def handle_message(
)
return False
if respond_team_member_list:
send_to, _ = fetch_userids_from_emails(respond_team_member_list, client)
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
respond_member_group_list, client
)
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client)
send_to = list(set(send_to + user_ids)) if send_to else user_ids
if missing_users:
logger.warning(f"Failed to find these users/groups: {missing_users}")
# If configured to respond to team members only, then cannot be used with a /DanswerBot command
# which would just respond to the sender
if respond_team_member_list and is_bot_msg:
if send_to and is_bot_msg:
if sender_id:
respond_in_thread(
client=client,
@@ -288,281 +221,28 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm = get_llm_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
# it allows the slack flow to extract out filters from the user query
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
auto_detect_filters = (
persona.llm_filter_extraction if persona is not None else True
)
if disable_auto_detect_filters:
auto_detect_filters = False
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
        # Optionally, respond in thread with the error message; used primarily
        # for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
return True
# Got an answer at this point, can remove reaction and give results
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
with Session(get_sqlalchemy_engine()) as db_session:
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
db_session=db_session,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
if used_standard_answer:
return False
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
        # This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
# if no standard answer applies, try a regular answer
issue_with_regular_answer = handle_regular_answer(
message_info=message_info,
slack_bot_config=slack_bot_config,
receiver_ids=send_to,
client=client,
channel=channel,
receiver_ids=send_to,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
)
        # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
        # the ephemeral message. This will also give the user a notification, which the ephemeral message does not.
if respond_team_member_list:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
return issue_with_regular_answer
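
The recipient resolution added above first looks up emails, then treats anything unresolved as a group name and unions the results. A sketch with the Slack lookups stubbed out (the stubs and their return values are invented):

respond_member_group_list = ["alice@example.com", "eng-oncall"]

def fetch_user_ids_from_emails(items):   # stub for the real Slack lookup
    return ["U0ALICE"], ["eng-oncall"]   # (resolved user ids, unresolved entries)

def fetch_user_ids_from_groups(items):   # stub for the real Slack lookup
    return ["U0BOB", "U0CAROL"], []      # (member ids of matching groups, misses)

send_to, missing_ids = fetch_user_ids_from_emails(respond_member_group_list)
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids)
send_to = list(set(send_to + user_ids)) if send_to else user_ids
# send_to now holds U0ALICE plus the members of the "eng-oncall" group
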

View File

@@ -0,0 +1,465 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def handle_regular_answer(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
logger: logging.Logger,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
bypass_acl = False
if (
slack_bot_config
and slack_bot_config.persona
and slack_bot_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_bot_config is None
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to:
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
)
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
# it allows the slack flow to extract out filters from the user query
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
# Commenting this out because this is only available to the slackbot for now
# later we plan to implement this at the persona level where this will get
# commented back in
# auto_detect_filters = (
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = (
slack_bot_config.enable_auto_filters
if slack_bot_config is not None
else False
)
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
        # Optionally, respond in thread with the error message; used primarily
        # for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
return True
# Edge case handling, for tracking down the Slack usage issue
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=[
SectionBlock(
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
            # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
            # the ephemeral message. This will also give the user a notification, which the ephemeral message does not.
if receiver_ids:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
        # This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
only_respond_with_citations_or_quotes = (
channel_conf
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer.citations or answer.quotes)
if (
only_respond_with_citations_or_quotes
and not has_citations_or_quotes
and not message_info.bypass_filters
):
logger.error(
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no citations or quotes when trying to answer.",
thread_ts=message_ts_to_respond_to,
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DMs (ephemeral messages), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also gives the user a notification, which an ephemeral message does not.
if receiver_ids:
send_team_member_message(
client=client,
channel=channel,
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
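For readers following the flow above, a minimal, hypothetical sketch of the citations/quotes postfilter decision; the config dict and flag values below are placeholders, not values from the diff.
# Illustrative sketch of the `well_answered_postfilter` check
channel_conf = {"answer_filters": ["well_answered_postfilter"]}
has_citations_or_quotes = False  # e.g. answer.citations and answer.quotes are both empty
bypass_filters = False
should_stay_silent = (
    bool(channel_conf)
    and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
    and not has_citations_or_quotes
    and not bypass_filters
)
print(should_stay_silent)  # True -> the bot drops the answer instead of responding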

View File

@@ -0,0 +1,216 @@
import logging
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_messages_by_sessions
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Check whether the user message matches any configured standard answers.
Returns the list of matching StandardAnswers (empty if none match).
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: logging.Logger,
client: WebClient,
db_session: Session,
) -> bool:
"""
Potentially respond to the user message depending on whether the user's message matches
any of the configured standard answers and also whether those answers have already been
provided in the current thread.
Returns True if a matching standard answer was found and sent, in which case the caller
does not need to generate a regular answer; returns False otherwise.
"""
# if no channel config, then no standard answers are configured
if not slack_bot_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_bot_config.standard_answer_categories if slack_bot_config else []
)
configured_standard_answers = set(
[
standard_answer
for standard_answer_category in configured_standard_answer_categories
for standard_answer in standard_answer_category.standard_answers
]
)
query_msg = message_info.thread_messages[-1]
if slack_thread_id is None:
used_standard_answer_ids = set([])
else:
chat_sessions = get_chat_sessions_by_slack_thread_id(
slack_thread_id=slack_thread_id,
user_id=None,
db_session=db_session,
)
chat_messages = get_chat_messages_by_sessions(
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
user_id=None,
db_session=db_session,
skip_permission_check=True,
)
used_standard_answer_ids = set(
[
standard_answer.id
for chat_message in chat_messages
for standard_answer in chat_message.standard_answers
]
)
usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids
)
if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers(
query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session,
)
else:
matching_standard_answers = []
if matching_standard_answers:
chat_session = create_chat_session(
db_session=db_session,
description="",
user_id=None,
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
danswerbot_flow=True,
slack_thread_id=slack_thread_id,
one_shot=True,
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
formatted_answers = []
for standard_answer in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = (
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
)
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
message=answer_message,
token_count=0,
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=True,
)
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
)
answer_blocks = build_standard_answer_blocks(
answer_message=answer_message,
)
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_info.msg_to_respond,
unfurl=False,
)
if receiver_ids and slack_thread_id:
send_team_member_message(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
)
return True
except Exception as e:
logger.exception(f"Unable to send standard answer message: {e}")
return False
else:
return False
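As an aside, a small self-contained sketch of the Slack block-quote formatting used for standard answers above; the answer text is a placeholder.
standard_answer_text = "Restart the service.\nThen check the logs."
block_quotified = ">" + standard_answer_text.replace("\n", "\n> ")
print(block_quotified)
# >Restart the service.
# > Then check the logs.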

View File

@@ -0,0 +1,19 @@
from slack_sdk import WebClient
from danswer.danswerbot.slack.utils import respond_in_thread
def send_team_member_message(
client: WebClient,
channel: str,
thread_ts: str,
) -> None:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=thread_ts,
)
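A hypothetical usage sketch of this helper; the token, channel ID, and thread timestamp below are placeholders, not values from the diff.
from slack_sdk import WebClient
client = WebClient(token="xoxb-placeholder-token")  # placeholder token
send_team_member_message(
    client=client,
    channel="C0123456789",
    thread_ts="1721430000.000100",
)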

View File

@@ -4,9 +4,9 @@ from danswer.configs.constants import DocumentSource
def source_to_github_img_link(source: DocumentSource) -> str | None:
# TODO: store these images somewhere better
if source == DocumentSource.WEB.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Web.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Web.png"
if source == DocumentSource.FILE.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
if source == DocumentSource.GOOGLE_SITES.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleSites.png"
if source == DocumentSource.SLACK.value:
@@ -20,13 +20,13 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
if source == DocumentSource.GITLAB.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gitlab.png"
if source == DocumentSource.CONFLUENCE.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Confluence.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Confluence.png"
if source == DocumentSource.JIRA.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Jira.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Jira.png"
if source == DocumentSource.NOTION.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Notion.png"
if source == DocumentSource.ZENDESK.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Zendesk.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Zendesk.png"
if source == DocumentSource.GONG.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gong.png"
if source == DocumentSource.LINEAR.value:
@@ -38,7 +38,7 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
if source == DocumentSource.ZULIP.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Zulip.png"
if source == DocumentSource.GURU.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Guru.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Guru.png"
if source == DocumentSource.HUBSPOT.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/HubSpot.png"
if source == DocumentSource.DOCUMENT360.value:
@@ -51,8 +51,8 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Sharepoint.png"
if source == DocumentSource.REQUESTTRACKER.value:
# just use file icon for now
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
if source == DocumentSource.INGESTION_API.value:
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"

View File

@@ -10,6 +10,7 @@ from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
@@ -17,6 +18,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
@@ -26,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_but
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_followup_resolved_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_generate_answer_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.handlers.handle_message import (
@@ -40,6 +45,7 @@ from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
@@ -54,6 +60,22 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot. Adding this to avoid exploding LLM costs while we track down the cause.
_SLACK_GREETINGS_TO_IGNORE = {
"Welcome back!",
"It's going to be a great day.",
"Salutations!",
"Greetings!",
"Feeling great!",
"Hi there",
":wave:",
}
# this is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
@@ -76,6 +98,28 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
channel_specific_logger.error("Cannot respond to empty message - skipping")
return False
if (
req.payload.setdefault("event", {}).get("user", "")
== _OFFICIAL_SLACKBOT_USER_ID
):
channel_specific_logger.info(
"Ignoring messages from Slack's official Slackbot"
)
return False
if (
msg in _SLACK_GREETINGS_TO_IGNORE
or remove_danswer_bot_tag(msg, client=client.web_client)
in _SLACK_GREETINGS_TO_IGNORE
):
channel_specific_logger.error(
f"Ignoring weird Slack greeting message: '{msg}'"
)
channel_specific_logger.error(
f"Weird Slack greeting message payload: '{req.payload}'"
)
return False
# Ensure that the message is a new message of expected type
event_type = event.get("type")
if event_type not in ["app_mention", "message"]:
@@ -157,6 +201,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
)
return False
logger.debug(f"Handling Slack request with Payload: '{req.payload}'")
return True
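A tiny illustrative sketch of the greeting prefilter added above; the message text is a placeholder.
msg = "Greetings!"
if msg in _SLACK_GREETINGS_TO_IGNORE:
    # trivial greeting: skip it rather than spend an LLM call
    print("ignoring greeting")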
@@ -201,6 +246,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.info(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
if tagged:
logger.info("User tagged DanswerBot")
@@ -217,6 +270,7 @@ def build_request_details(
thread_messages=thread_messages,
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=event.get("user") or None,
bypass_filters=tagged,
is_bot_msg=False,
@@ -234,6 +288,7 @@ def build_request_details(
thread_messages=[single_msg],
channel_to_respond=channel,
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
bypass_filters=True,
is_bot_msg=True,
@@ -303,7 +358,7 @@ def process_message(
failed = handle_message(
message_info=details,
channel_config=slack_bot_config,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
@@ -341,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_followup_resolved_button(req, client, immediate=True)
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
return handle_followup_resolved_button(req, client, immediate=False)
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
@@ -412,13 +469,13 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated

View File

@@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel):
thread_messages: list[ThreadMessage]
channel_to_respond: str
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot

View File

@@ -29,7 +29,12 @@ from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -41,6 +46,30 @@ logger = setup_logger()
DANSWER_BOT_APP_ID: str | None = None
def rephrase_slack_message(msg: str) -> str:
def _get_rephrase_message() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": SLACK_LANGUAGE_REPHRASE_PROMPT.format(query=msg),
},
]
return messages
try:
llm, _ = get_default_llms(timeout=5)
except GenAIDisabledException:
logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
return msg
messages = _get_rephrase_message()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
logger.debug(model_output)
return model_output
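A hedged usage sketch: when Gen AI is disabled, the helper above falls back to returning the original message. The example text is a placeholder.
original = "can u help me find the vpn setup doc pls"
rephrased = rephrase_slack_message(original)
# With Gen AI disabled this is just `original`; otherwise it is the LLM's rephrasing.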
def update_emote_react(
emoji: str,
channel: str,
@@ -48,17 +77,25 @@ def update_emote_react(
remove: bool,
client: WebClient,
) -> None:
if not message_ts:
logger.error(f"Tried to remove a react in {channel} but no message specified")
return
try:
if not message_ts:
logger.error(
f"Tried to remove a react in {channel} but no message specified"
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
else:
logger.error(f"Was not able to react to user message due to: {e}")
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
@@ -107,16 +144,13 @@ def respond_in_thread(
receiver_ids: list[str] | None = None,
metadata: Metadata | None = None,
unfurl: bool = True,
) -> None:
) -> list[str]:
if not text and not blocks:
raise ValueError("One of `text` or `blocks` must be provided")
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
if not receiver_ids:
response = slack_call(
channel=channel,
text=text,
@@ -128,7 +162,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
response = slack_call(
channel=channel,
@@ -142,6 +178,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
return message_ids
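Since respond_in_thread now returns the posted message timestamps, a caller might use it as sketched below; the channel and thread values are placeholders and `client` is assumed to be a configured WebClient.
message_ids = respond_in_thread(
    client=client,
    channel="C0123456789",
    text="Posting a test reply",
    thread_ts="1721430000.000100",
)
# e.g. ["1721430001.000200"] - one ts per message actually posted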
def build_feedback_id(
@@ -263,7 +302,7 @@ def get_channel_name_from_id(
raise e
def fetch_userids_from_emails(
def fetch_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -279,32 +318,72 @@ def fetch_userids_from_emails(
return user_ids, failed_to_find
def fetch_groupids_from_names(
names: list[str], client: WebClient
def fetch_user_ids_from_groups(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
group_ids: set[str] = set()
user_ids: list[str] = []
failed_to_find: list[str] = []
try:
response = client.usergroups_list()
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return user_ids, given_names
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
group_id = name_id_map.get(given_name) or handle_id_map.get(
given_name.lstrip("@")
)
if not group_id:
failed_to_find.append(given_name)
continue
try:
response = client.usergroups_users_list(usergroup=group_id)
if isinstance(response.data, dict):
user_ids.extend(response.data.get("users", []))
else:
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user group ids: {str(e)}")
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user groups: {str(e)}")
failed_to_find = given_names
return user_ids, failed_to_find
def fetch_group_ids_from_names(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
group_data: list[str] = []
failed_to_find: list[str] = []
try:
response = client.usergroups_list()
if response.get("ok") and "usergroups" in response.data:
all_groups_dicts = response.data["usergroups"] # type: ignore
name_id_map = {d["name"]: d["id"] for d in all_groups_dicts}
handle_id_map = {d["handle"]: d["id"] for d in all_groups_dicts}
for group in names:
if group in name_id_map:
group_ids.add(name_id_map[group])
elif group in handle_id_map:
group_ids.add(handle_id_map[group])
else:
failed_to_find.append(group)
else:
# Most likely a Slack App scope issue
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return group_data, given_names
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
id = handle_id_map.get(given_name.lstrip("@"))
id = id or name_id_map.get(given_name)
if id:
group_data.append(id)
else:
failed_to_find.append(given_name)
except Exception as e:
failed_to_find = given_names
logger.error(f"Error fetching user groups: {str(e)}")
return list(group_ids), failed_to_find
return group_data, failed_to_find
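A hypothetical call pattern for the renamed helpers above; the group name is a placeholder and `client` is assumed to be a configured WebClient.
user_ids, unresolved_users = fetch_user_ids_from_groups(["@oncall-team"], client)
group_ids, unresolved_groups = fetch_group_ids_from_names(["oncall-team"], client)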
def fetch_user_semantic_id_from_id(

View File

@@ -1,4 +1,5 @@
from collections.abc import AsyncGenerator
from collections.abc import Callable
from typing import Any
from typing import Dict
@@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
Only used in the EE version. For MIT, just return empty list."""
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
)
return get_default_admin_user_emails_fn()
async def get_user_count() -> int:
@@ -32,7 +47,7 @@ async def get_user_count() -> int:
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC

View File

@@ -1,46 +1,45 @@
from collections.abc import Sequence
from functools import lru_cache
from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
from danswer.db.models import ChatSessionSharedStatus
from danswer.db.models import DocumentSet as DBDocumentSet
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
from danswer.db.models import Prompt
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.db.models import Tool
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -75,17 +74,59 @@ def get_chat_session_by_id(
return chat_session
def get_chat_sessions_by_slack_thread_id(
slack_thread_id: str,
user_id: UUID | None,
db_session: Session,
) -> Sequence[ChatSession]:
stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id)
if user_id is not None:
stmt = stmt.where(
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
)
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
subquery = (
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
.where(
and_(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER, # Select USER messages
)
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
subquery,
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
& (ChatMessage.id == subquery.c.min_id),
)
first_messages = db_session.execute(query).all()
return dict([(row.chat_session_id, row.message) for row in first_messages])
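A hedged usage sketch of the helper above; the session IDs and resulting messages are placeholders.
with Session(get_sqlalchemy_engine()) as db_session:
    first_by_session = get_first_messages_for_chat_sessions([1, 2, 3], db_session)
    # e.g. {1: "How do I reset my password?", 2: "Where are the VPN docs?"}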
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_one_shot: bool = False,
only_one_shot: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_one_shot:
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -95,15 +136,71 @@ def get_chat_sessions_by_user(
return list(chat_sessions)
def delete_search_doc_message_relationship(
message_id: int, db_session: Session
) -> None:
db_session.query(ChatMessage__SearchDoc).filter(
ChatMessage__SearchDoc.chat_message_id == message_id
).delete(synchronize_session=False)
db_session.commit()
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
db_session.commit()
def delete_orphaned_search_docs(db_session: Session) -> None:
orphaned_docs = (
db_session.query(SearchDoc)
.outerjoin(ChatMessage__SearchDoc)
.filter(ChatMessage__SearchDoc.chat_message_id.is_(None))
.all()
)
for doc in orphaned_docs:
db_session.delete(doc)
db_session.commit()
def delete_messages_and_files_from_chat_session(
chat_session_id: int, db_session: Session
) -> None:
# Select all messages in this chat session, along with any files attached to them
messages_with_files = db_session.execute(
select(ChatMessage.id, ChatMessage.files).where(
ChatMessage.chat_session_id == chat_session_id,
)
).fetchall()
for id, files in messages_with_files:
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
for file_info in files or {}:
lobj_name = file_info.get("id")
if lobj_name:
logger.info(f"Deleting file with name: {lobj_name}")
delete_lobj_by_name(lobj_name, db_session)
db_session.execute(
delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
)
db_session.commit()
delete_orphaned_search_docs(db_session)
def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int | None = None,
persona_id: int,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
danswerbot_flow: bool = False,
slack_thread_id: str | None = None,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
@@ -113,6 +210,7 @@ def create_chat_session(
prompt_override=prompt_override,
one_shot=one_shot,
danswerbot_flow=danswerbot_flow,
slack_thread_id=slack_thread_id,
)
db_session.add(chat_session)
@@ -151,25 +249,30 @@ def delete_chat_session(
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if hard_delete:
stmt_messages = delete(ChatMessage).where(
ChatMessage.chat_session_id == chat_session_id
)
db_session.execute(stmt_messages)
stmt = delete(ChatSession).where(ChatSession.id == chat_session_id)
db_session.execute(stmt)
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
else:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
chat_session.deleted = True
db_session.commit()
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
old_sessions = db_session.execute(
select(ChatSession.user_id, ChatSession.id).where(
ChatSession.time_created < cutoff_time
)
).fetchall()
for user_id, session_id in old_sessions:
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
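A hypothetical retention job built on the helper above; the 30-day cutoff is an arbitrary example value.
with Session(get_sqlalchemy_engine()) as db_session:
    delete_chat_sessions_older_than(days_old=30, db_session=db_session)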
def get_chat_message(
chat_message_id: int,
user_id: UUID | None,
@@ -195,11 +298,45 @@ def get_chat_message(
return chat_message
def get_chat_messages_by_sessions(
chat_session_ids: list[int],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
) -> Sequence[ChatMessage]:
if not skip_permission_check:
for chat_session_id in chat_session_ids:
get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
stmt = (
select(ChatMessage)
.where(ChatMessage.chat_session_id.in_(chat_session_ids))
.order_by(nullsfirst(ChatMessage.parent_message))
)
return db_session.execute(stmt).scalars().all()
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
stmt = (
select(SearchDoc)
.join(
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
)
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
)
return list(db_session.scalars(stmt).all())
def get_chat_messages_by_session(
chat_session_id: int,
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
prefetch_tool_calls: bool = False,
) -> list[ChatMessage]:
if not skip_permission_check:
get_chat_session_by_id(
@@ -207,12 +344,16 @@ def get_chat_messages_by_session(
)
stmt = (
select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
# Start with the root message which has no parent
select(ChatMessage)
.where(ChatMessage.chat_session_id == chat_session_id)
.order_by(nullsfirst(ChatMessage.parent_message))
)
result = db_session.execute(stmt).scalars().all()
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
return list(result)
@@ -264,8 +405,10 @@ def create_new_chat_message(
rephrased_query: str | None = None,
error: str | None = None,
reference_docs: list[DBSearchDoc] | None = None,
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
) -> ChatMessage:
new_chat_message = ChatMessage(
@@ -279,7 +422,9 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
@@ -357,396 +502,6 @@ def get_prompt_by_id(
return prompt
@lru_cache()
def get_default_prompt() -> Prompt:
with Session(get_sqlalchemy_engine()) as db_session:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_persona_by_id(
persona_id: int,
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
# if user is an admin, they should have access to all Personas
if user is not None and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
result = db_session.execute(stmt)
persona = result.scalar_one_or_none()
if persona is None:
raise ValueError(
f"Persona with ID {persona_id} does not exist or does not belong to user"
)
return persona
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
return prompts
def get_personas_by_ids(
persona_ids: list[int], db_session: Session
) -> Sequence[Persona]:
"""Unsafe, can fetch personas from all users"""
if not persona_ids:
return []
personas = db_session.scalars(
select(Persona).where(Persona.id.in_(persona_ids))
).all()
return personas
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
"""Admins can see all, regular users can only fetch their own.
If user is None, assume the user is an admin or auth is disabled."""
stmt = select(Persona).where(Persona.name == persona_name)
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Persona.user_id == user.id)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def upsert_prompt(
user: User | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
commit: bool = True,
) -> Prompt:
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
if commit:
db_session.commit()
else:
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
def upsert_persona(
user: User | None,
name: str,
description: str,
num_chunks: float,
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
is_public: bool,
db_session: Session,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
) -> Persona:
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = get_persona_by_name(
persona_name=name, user=user, db_session=db_session
)
# Fetch and attach tools by IDs
tools = None
if tool_ids is not None:
tools = db_session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
if not tools and tool_ids:
raise ValueError("Tools not found")
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
persona.name = name
persona.description = description
persona.num_chunks = num_chunks
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets or []
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts
if tools is not None:
persona.tools = tools
else:
persona = Persona(
id=persona_id,
user_id=user.id if user else None,
is_public=is_public,
name=name,
description=description,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
default_persona=default_persona,
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
tools=tools or [],
)
db_session.add(persona)
if commit:
db_session.commit()
else:
# flush the session so that the persona has an ID
db_session.flush()
return persona
def mark_prompt_as_deleted(
prompt_id: int,
user: User | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
prompt.deleted = True
db_session.commit()
def mark_persona_as_deleted(
persona_id: int,
user: User | None,
db_session: Session,
) -> None:
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
persona.deleted = True
db_session.commit()
def mark_persona_as_not_deleted(
persona_id: int,
user: User | None,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
)
if persona.deleted:
persona.deleted = False
db_session.commit()
else:
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
def mark_delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
stmt = (
update(Persona)
.where(Persona.name == persona_name, Persona.default_persona == is_default)
.values(deleted=True)
)
db_session.execute(stmt)
db_session.commit()
def delete_old_default_personas(
db_session: Session,
) -> None:
"""Note, this locks out the Summarize and Paraphrase personas for now
Need a more graceful fix later or those need to never have IDs"""
stmt = (
update(Persona)
.where(Persona.default_persona, Persona.id > 0)
.values(deleted=True, name=func.concat(Persona.name, "_old"))
)
db_session.execute(stmt)
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_visible: bool,
db_session: Session,
) -> None:
persona = get_persona_by_id(persona_id=persona_id, user=None, db_session=db_session)
persona.is_visible = is_visible
db_session.commit()
def update_all_personas_display_priority(
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
"""Updates the display priority of all lives Personas"""
personas = get_personas(user_id=None, db_session=db_session)
available_persona_ids = {persona.id for persona in personas}
if available_persona_ids != set(display_priority_map.keys()):
raise ValueError("Invalid persona IDs provided")
for persona in personas:
persona.display_priority = display_priority_map[persona.id]
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_deleted: bool = False,
) -> Sequence[Prompt]:
stmt = select(Prompt).where(
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
if not include_default:
stmt = stmt.where(Prompt.default_prompt.is_(False))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_personas(
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
if user_id is not None:
# Subquery to find all groups the user belongs to
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user_id)
.subquery()
)
# Include personas where the user is directly related or part of a user group that has access
access_conditions = or_(
Persona.is_public == True, # noqa: E712
Persona.id.in_( # User has access through list of users with access
select(Persona__User.persona_id).where(Persona__User.user_id == user_id)
),
Persona.id.in_( # User is part of a group that has access
select(Persona__UserGroup.persona_id).where(
Persona__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
if not include_default:
stmt = stmt.where(Persona.default_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_doc_query_identifiers_from_model(
search_doc_ids: list[int],
chat_session: ChatSession,
@@ -764,16 +519,46 @@ def get_doc_query_identifiers_from_model(
)
raise ValueError("Docs references do not belong to user")
if any(
[doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs]
):
raise ValueError("Invalid reference doc, not from this chat session.")
try:
if any(
[
doc.chat_messages[0].chat_session_id != chat_session.id
for doc in search_docs
]
):
raise ValueError("Invalid reference doc, not from this chat session.")
except IndexError:
# This happens when the doc has no chat_messages associated with it,
# an edge case where the chat message failed to save.
# This usually happens when the LLM fails either immediately or partway through.
raise RuntimeError("Chat session failed, please start a new session.")
doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs]
return doc_query_identifiers
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
)
if relevance_data is not None:
db_session.execute(
update(SearchDoc)
.where(SearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
@@ -788,17 +573,19 @@ def create_db_search_doc(
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=server_search_doc.relevance_explanation,
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=server_search_doc.match_highlights,
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
is_internet=server_search_doc.is_internet,
)
db_session.add(db_search_doc)
db_session.commit()
return db_search_doc
@@ -824,14 +611,17 @@ def translate_db_search_doc_to_server_search_doc(
hidden=db_search_doc.hidden,
metadata=db_search_doc.doc_metadata if not remove_doc_content else {},
score=db_search_doc.score,
match_highlights=db_search_doc.match_highlights
if not remove_doc_content
else [],
match_highlights=(
db_search_doc.match_highlights if not remove_doc_content else []
),
relevance_explanation=db_search_doc.relevance_explanation,
is_relevant=db_search_doc.is_relevant,
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
secondary_owners=db_search_doc.secondary_owners
if not remove_doc_content
else [],
secondary_owners=(
db_search_doc.secondary_owners if not remove_doc_content else []
),
is_internet=db_search_doc.is_internet,
)
@@ -849,9 +639,11 @@ def get_retrieval_docs_from_chat_message(
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage, remove_doc_content: bool = False
chat_message: ChatMessage,
remove_doc_content: bool = False,
) -> ChatMessageDetail:
chat_msg_detail = ChatMessageDetail(
chat_session_id=chat_message.chat_session_id,
message_id=chat_message.id,
parent_message=chat_message.parent_message,
latest_child_message=chat_message.latest_child_message,
@@ -864,18 +656,15 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
)
return chat_msg_detail
def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
stmt = delete(Persona).where(
Persona.name == persona_name, Persona.default_persona == is_default
)
db_session.execute(stmt)
db_session.commit()

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
@@ -84,6 +85,9 @@ def create_connector(
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
prune_freq=connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ,
disabled=connector_data.disabled,
)
db_session.add(connector)
@@ -113,6 +117,11 @@ def update_connector(
connector.input_type = connector_data.input_type
connector.connector_specific_config = connector_data.connector_specific_config
connector.refresh_freq = connector_data.refresh_freq
connector.prune_freq = (
connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ
)
connector.disabled = connector_data.disabled
db_session.commit()
@@ -259,6 +268,7 @@ def create_initial_default_connector(db_session: Session) -> None:
input_type=InputType.LOAD_STATE,
connector_specific_config={},
refresh_freq=None,
prune_freq=None,
)
db_session.add(connector)
db_session.commit()

View File

@@ -151,7 +151,8 @@ def add_credential_to_connector(
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
user: User,
is_public: bool,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
@@ -185,6 +186,7 @@ def add_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,
is_public=is_public,
)
db_session.add(association)
db_session.commit()
@@ -199,7 +201,7 @@ def add_credential_to_connector(
def remove_credential_from_connector(
connector_id: int,
credential_id: int,
user: User,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
