Compare commits

..

154 Commits

Author SHA1 Message Date
pablodanswer
8d66bdd061 functional notifications 2024-10-21 17:49:13 -07:00
pablodanswer
8f67f1715c minor typing 2024-10-20 14:48:19 -07:00
pablodanswer
3b365509e2 k 2024-10-20 14:41:12 -07:00
pablodanswer
022cbdfccf robustified cloud auth type 2024-10-20 14:28:22 -07:00
pablodanswer
ebec6f6b10 k 2024-10-20 13:43:08 -07:00
pablodanswer
1cad9c7b3d add cloud auth type 2024-10-20 13:43:08 -07:00
pablodanswer
b4e975013c k 2024-10-20 13:42:38 -07:00
pablodanswer
dd26f92206 nit 2024-10-20 13:41:41 -07:00
pablodanswer
4d00ec45ad remove comments + notice logs 2024-10-20 13:34:13 -07:00
pablodanswer
1a81c67a67 k 2024-10-20 13:22:00 -07:00
pablodanswer
04f965e656 k 2024-10-20 11:52:24 -07:00
pablodanswer
277d37e0ee fix 2024-10-20 11:45:00 -07:00
pablodanswer
3cd260131b k 2024-10-20 10:16:19 -07:00
pablodanswer
ad21ee0e9a fix mysterious syncing issue! 2024-10-19 19:26:57 -07:00
pablodanswer
c7dc0e9af0 k 2024-10-19 19:15:55 -07:00
pablodanswer
75c5de802b ensure tenant id passed 2024-10-19 19:15:55 -07:00
pablodanswer
c39f590d0d k 2024-10-19 19:15:55 -07:00
pablodanswer
82a9fda846 add types 2024-10-19 19:15:55 -07:00
pablodanswer
842d4ab2a8 k 2024-10-19 19:15:55 -07:00
pablodanswer
cddcec4ea4 k 2024-10-19 19:15:55 -07:00
pablodanswer
09dd7b424c validated workaround for flush + reset 2024-10-19 19:15:55 -07:00
pablodanswer
a2fd8d5e0a add some more multi tenancy 2024-10-19 19:15:55 -07:00
pablodanswer
802dc00f78 k 2024-10-19 19:15:55 -07:00
pablodanswer
f745ca1e03 ensure **all** sharp-related packages installed (#2855) 2024-10-19 23:44:11 +00:00
pablodanswer
eaaa135f90 push vespa managed service configs (#2857)
* push vespa managed service configs

* organize

* k

* k

* k

* nit

* k

* minor cleanup

* ensure no unnecessary timeout
2024-10-19 23:43:26 +00:00
rkuo-danswer
457e7992a4 missing tenant_id as optional param (#2851)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-19 21:10:42 +00:00
pablodanswer
2fb1d06fbf update google sites + formik (#2834)
* update google sites + formik

* nit

* k
2024-10-19 21:03:04 +00:00
pablodanswer
8f9d4335ce (minor) search memoization + context (#2732)
* add markdown blocks to search

* nit

* k
2024-10-19 19:13:21 +00:00
pablodanswer
ee1cb084ac modify default (#2856) 2024-10-19 19:12:42 +00:00
pablodanswer
2c77ad2aab Add errors to search (#2854)
* minor - add errors to search

* k
2024-10-19 19:11:46 +00:00
pablodanswer
f7d77a3c76 Empty embedding fix (#2853)
* account for malformed urls

* fix

* k
2024-10-19 17:55:39 +00:00
pablodanswer
8b220d2dba Add assistant notifications + update assistant context (#2816)
* add assistant notifications

* nit

* update context

* validated

* ensure context passed properly

* validated + cleaned

* nit: naming

* k

* k

* final validation + new ui

* nit + video

* nit

* nit

* nit

* k

* fix typos
2024-10-19 01:21:11 +00:00
rkuo-danswer
6913efef90 fresh indexing feature branch (#2790)
* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a62422068.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* fix where num_indexing_workers falls back

* remove extra brace
2024-10-18 22:40:05 +00:00
rkuo-danswer
12cbbe6cee use with for update instead of serializable (#2848)
* use with for update instead of serializable

* remove tenant logic handled now by get_session_with_tenant

* remove usage of begin_nested ... it's not necessary
2024-10-18 20:35:23 +00:00
hagen-danswer
55de519364 Cleanup connector form (#2849)
* move "advanced options" to the bottom of the form and cleanup curator frontend

* troll
2024-10-18 18:44:24 +00:00
Chris Weaver
36134021c5 Refactor + add global timeout env variable (#2844)
* Refactor + add global timeout env variable

* remove model

* mypy

* Remove unused
2024-10-18 18:25:27 +00:00
rkuo-danswer
5b78299880 use native rate limiting in the confluence client (#2837)
* use native rate limiting in the confluence client

* upgrade urllib3 to v2.2.3 to support retries in confluence client

* improve logging so that progress is visible.
2024-10-18 18:15:43 +00:00
Richard Kuo (Danswer)
59364aadd7 Revert "no serializable, use with_for_update to lock the row."
This reverts commit e12785d277.
2024-10-18 11:10:09 -07:00
Richard Kuo (Danswer)
e12785d277 no serializable, use with_for_update to lock the row. 2024-10-18 11:07:54 -07:00
pablodanswer
7906d9edc8 Add all-tenants migration for K8 job (#2846)
* add migration

* update migration logic for tenants

* k

* k

* k

* k
2024-10-18 02:55:05 +00:00
pablodanswer
6e54c97326 multitenant setup (#2845) 2024-10-17 17:54:02 -07:00
pablodanswer
61424de531 add sentry (#2786)
* add sentry

* nit

* nit

* add requirement to ee

* try to ensure sentry is installed in integration tests
2024-10-17 23:20:37 +00:00
rkuo-danswer
4c2cf8b132 always finalize the serialized transaction so that it doesn't leak ou… (#2843)
* always finalize the serialized transaction so that it doesn't leak outside the function

* re-raise the exception and log it
2024-10-17 23:13:57 +00:00
pablodanswer
b169f78699 Push multi tenancy for slackbot (#2828)
* push multi tenancy for slackbot

* move to utils

* k

* k

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-10-17 21:04:48 +00:00
pablodanswer
e48086b1c2 add slack markdown formatting (#2829)
* add slack markdown formatting

* nit

* k
2024-10-17 20:27:57 +00:00
hagen-danswer
6b8ecb3a4b Merge pull request #2838 from danswer-ai/dont-fail-flaky
dont fail flaky tests
2024-10-17 13:38:32 -07:00
hagen-danswer
deb66a88aa dont fail flaky tests 2024-10-17 13:37:50 -07:00
hagen-danswer
90bd535c48 Merge pull request #2836 from danswer-ai/flakey-test-run-but-dont-fail
Make flakey test still run but not fail CI
2024-10-17 13:31:00 -07:00
rkuo-danswer
0de487064a lock to avoid rare serializable errors (#2818)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-17 19:44:51 +00:00
rkuo-danswer
114326d11a fix sync to use update_single (#2822) 2024-10-17 19:43:34 +00:00
rkuo-danswer
389c7b72db Bugfix/monitor exceptions (#2830)
* do a rollback before more db work

* warn if not all doc_by_cc_pair entries were deleted

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-17 19:43:19 +00:00
hagen-danswer
28ad01a51a py 2024-10-17 12:37:34 -07:00
hagen-danswer
0c102ebb5c simplified the document search function 2024-10-17 12:13:42 -07:00
hagen-danswer
5063b944ec Make flakey test still run but not fail CI 2024-10-17 11:36:59 -07:00
pablodanswer
15afe4dc78 bump litellm (#2827) 2024-10-17 18:05:35 +00:00
pablodanswer
a159779d39 prevent alembic from configuring logger (#2826)
* k

* k
2024-10-17 16:31:17 +00:00
hagen-danswer
44ebe3ae31 Merge pull request #2833 from danswer-ai/con-perm-sync-fix
Added logging for when a member has no email or username
2024-10-17 09:03:33 -07:00
hagen-danswer
938a65628d rearrange logging 2024-10-17 09:01:51 -07:00
hagen-danswer
5d390b65eb Added logging for when a member has no email or username 2024-10-17 08:47:46 -07:00
Chris Weaver
33974fc12c Add support for passthrough auth for custom tool calls (#2824)
* Add support for passthrough auth for custom tool calls

* Fix formatting
2024-10-16 22:50:16 +00:00
pablodanswer
db0779dd02 Session id: int -> UUID (#2814)
* session id: int -> UUID

* nit

* validated

* validated downgrade + upgrade + all functionality

* nit

* minor nit

* fix test case
2024-10-16 22:18:45 +00:00
pablodanswer
f3fb7c572e ensure assistant response parsed correctly (#2823) 2024-10-16 20:21:04 +00:00
rkuo-danswer
0a0215ceee check last_pruned instead of is_pruning (#2748)
* check last_pruned instead of is_pruning

* try using the ThreadingHTTPServer class for stability and avoiding blocking single-threaded behavior

* add startup delay to web server in test

* just explicitly return None if we can't parse the datetime

* switch to uvicorn for test stability
2024-10-16 18:52:27 +00:00
pablodanswer
1a9921f63e Redirect with query param (#2811)
* validated

* k

* k

* k

* minor update
2024-10-16 17:26:44 +00:00
pablodanswer
a385234c0e Parsing (#2734)
* k

* update chunking limits

* nit

* nit

* clean up types

* nit

* validate

* k
2024-10-16 16:44:19 +00:00
pablodanswer
65573210f1 add llama 3.2 (#2812) 2024-10-16 09:00:32 -07:00
Yuhong Sun
c148fa5bfa Notion Recurse Empty Final Field (#2819) 2024-10-15 23:03:37 -07:00
pablodanswer
11372aac8f Add custom tool headers (#2773)
* add custom tool headers

* simplify

* k

* k

* k

* nit
2024-10-16 04:37:00 +00:00
Yuhong Sun
f23a89ccfd Notion Empty Property Fix (#2817) 2024-10-15 21:52:00 -07:00
pablodanswer
e022e77b6d Simpler azure embedding (#2751)
* functional but janky

* nit

* adapt for azure

* nit

* minor updates

* nits

* nit

* nit

* ensure access to litellm

* k
2024-10-15 23:23:11 +00:00
pablodanswer
02cc211e91 improved code block copying (#2802)
* improved code block copying

* k
2024-10-15 23:22:40 +00:00
pablodanswer
bfe963988e various multi tenant improvements (#2803)
* various multi tenant improvements

* nit

* ensure consistent db session operations

* minor robustification
2024-10-15 20:10:57 +00:00
pablodanswer
0e6c2f0b51 add ca option (#2774) 2024-10-15 19:23:04 +00:00
pablodanswer
98e88e2715 ensure shared chats are shared (#2801)
* ensure shared chats are shared

* k

* k

* nit

* k
2024-10-15 17:26:01 +00:00
pablodanswer
da46f61123 Ensure regenerate has dropdown too (#2797)
* ensure regenerate has dropdown too

* ensure applied to all

* nit
2024-10-15 17:09:13 +00:00
rkuo-danswer
aa5be37f97 fix index attempt refreshing automatically (#2791)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-15 02:59:33 +00:00
rkuo-danswer
efe2e79f27 Rate limiting confluence through redis (#2798)
* try rate limiting through redis

* fix circular import issue

* fix bad formatting of family string

* Revert "fix bad formatting of family string"

This reverts commit be688899e5.

* redis usage optional

* disable test that doesn't match with new design
2024-10-14 23:51:24 +00:00
pablodanswer
6f9740d026 Ensure warmup occurs once (#2777)
* ensure shared chats are shared

* ensure warm occurs once

* nit

* ensure warmup occurs once

* Revert "ensure shared chats are shared"

This reverts commit 8be887f3ee.
2024-10-14 22:53:02 +00:00
rkuo-danswer
dee197570d Bugfix/mediawiki (#2800)
* fix formatting

* fix poorly structured doc id, fix empty page id, fix family_class_dispatch invalid name (no spaces), fix setting id with int pageid

* fix mediawiki test
2024-10-14 22:48:06 +00:00
Chris Weaver
f8a7749b46 Fix file too large error (#2799)
* Fix file too large error

* Add cannotDownloadFile
2024-10-14 14:47:36 -07:00
hagen-danswer
494fda906d Confluence permission sync fix for server deployment (#2784)
* initial commit

* Made perm sync with with cql

* filter fix

* undo connector changes

* fixed everything

* whoops
2024-10-14 20:52:57 +00:00
pablodanswer
89eaa8bc30 nit (#2795) 2024-10-14 11:39:41 -07:00
Weves
9537a2581e Handle 'cannotExportFile' + fix forms 2024-10-14 09:49:54 -07:00
Weves
3ccd951307 Fix stopping of indexing runs when pausing a connector 2024-10-14 09:47:53 -07:00
Yuhong Sun
ba712d447d Notion Connector Improvements (#2789) 2024-10-13 23:01:17 -07:00
pablodanswer
a9bcc89a2c Add cursor to cql confluence (#2775)
* add cursor to cql confluence

* k

* k

* fixed space indexing issue

* fixed .get

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-10-14 02:09:17 +00:00
pablodanswer
ded42e2036 nit (#2787) 2024-10-14 01:22:14 +00:00
OMKAR MAKHARE
86ecf8e0fc Update README.md
- Corrected misspelling of Noteable to Notable.
2024-10-13 14:53:26 -07:00
Yuhong Sun
b393af676c Mypy (#2785) 2024-10-13 14:35:56 -07:00
Chris Weaver
26bdb41e8f Fix parallel tool calls (#2779)
* Fix parallel tool calls

* remove comments
2024-10-13 03:29:18 +00:00
Weves
3365e0b16e Fix tag background 2024-10-12 19:18:17 -07:00
pablodanswer
40dc4708d2 slightly cleaner loading (#2776) 2024-10-13 01:44:28 +00:00
pablodanswer
20df20ae51 Multi tenant vespa (#2762)
* add vespa multi tenancy

* k

* formatting

* Billing (#2667)

* k

* data -> control

* nit

* nit: error handling

* auth + app

* nit: color standardization

* nit

* nit: typing

* k

* k

* feat: functional upgrading

* feat: add block for downgrading to seats < active users

* add auth

* remove accomplished todo + prints

* nit

* tiny nit

* nit: centralize security

* add tenant expulsion/gating + invite user -> increment billing seat no.

* add cloud configs

* k

* k

* nit: update

* k

* k

* k

* k

* nit
2024-10-12 23:53:11 +00:00
rkuo-danswer
7eafdae17f update several github actions to silence github deprecation warnings (#2730) 2024-10-12 23:40:20 +00:00
pablodanswer
301032f59e k (#2772) 2024-10-12 03:10:09 +00:00
pablodanswer
b75b8334a6 k (#2771) 2024-10-12 03:04:48 +00:00
pablodanswer
d25de6e1cb fix web connector (#2769) 2024-10-12 00:41:15 +00:00
pablodanswer
d892203821 fix typo (#2768) 2024-10-12 00:39:09 +00:00
Chris Weaver
35d32ea3b0 Fix indexing model server port for warmup (#2767) 2024-10-11 04:24:34 +00:00
pablodanswer
1581d35476 account for no visible assistants (#2765) 2024-10-10 19:34:30 +00:00
hagen-danswer
1f4fe42f4b Add cql support for confluence connector (#2679)
* Added CQL support for Confluence

* changed string substitutions for CQL

* final cleanup

* updated string fixes

* remove print statements

* Update description
2024-10-10 19:16:56 +00:00
hagen-danswer
101b010c5c Improved logging and added comments (#2763)
* Improved logging and added comments

* fix exception logging

* cleanup
2024-10-10 17:37:27 +00:00
Yuhong Sun
b212b228fb Typo Fix (#2766) 2024-10-10 10:21:30 -07:00
Yuhong Sun
85d5e6c02f PDF Encrypted Case (#2764) 2024-10-10 10:17:18 -07:00
pablodanswer
f40c5ca9bd Add tenant context (#2596)
* add proper tenant context to background tasks

* update for new session logic

* remove unnecessary functions

* add additional tenant context

* update ports

* proper format / directory structure

* update ports

* ensure tenant context properly passed to ee bg tasks

* add user provisioning

* nit

* validated for multi tenant

* auth

* nit

* nit

* nit

* nit

* validate pruning

* evaluate integration tests

* at long last, validated celery beat

* nit: minor edge case patched

* minor

* validate update

* nit
2024-10-10 16:34:32 +00:00
Chris Weaver
9be54a2b4c Fix slack bot follow up questions (#2756) 2024-10-10 02:08:09 +00:00
pablodanswer
b4417fabd7 ensure shared assistants accessible via query params (#2740) 2024-10-10 01:47:38 +00:00
rkuo-danswer
2d74d44538 update indexing and slack bot to use stdout options (#2752) 2024-10-10 00:31:54 +00:00
pablodanswer
30d17ef9ee Convert images to jpeg (#2737)
* convert to jpeg

* k

* typing
2024-10-09 21:56:45 +00:00
hagen-danswer
804de3248e google drive permission sync cleanup (#2749) 2024-10-09 21:17:22 +00:00
rkuo-danswer
1cbc067483 print various celery queue lengths (#2729)
* print various celery queue lengths

* use the correct redis client

* mypy ignore
2024-10-09 20:37:34 +00:00
pablodanswer
6c0a0b6454 Add sync status (#2743)
* add sync status

* nit
2024-10-09 19:52:34 +00:00
Richard Kuo (Danswer)
ca88100f38 add branching 2024-10-09 12:27:28 -07:00
Richard Kuo (Danswer)
7c9f605a99 fix pr merge command 2024-10-09 11:44:47 -07:00
Richard Kuo (Danswer)
fbf09c7859 try to update token permissions 2024-10-09 11:07:20 -07:00
Richard Kuo (Danswer)
28fe0d12ca try capturing gh output and parsing 2024-10-09 10:55:47 -07:00
Richard Kuo (Danswer)
d403840507 fix where GH_TOKEN is set 2024-10-09 10:40:33 -07:00
Richard Kuo (Danswer)
174dabf52f edit step name 2024-10-09 10:38:59 -07:00
Richard Kuo (Danswer)
03807688e6 gh cli needs its token 2024-10-09 10:35:07 -07:00
Richard Kuo (Danswer)
8bbf5053de add deploy key 2024-10-09 10:29:41 -07:00
Richard Kuo (Danswer)
d6b4c08d24 need git user 2024-10-09 10:21:31 -07:00
Richard Kuo (Danswer)
af8e361fc2 handle merge commits during cherry picking 2024-10-09 10:16:36 -07:00
rkuo-danswer
7ce276bbe1 Merge pull request #2738 from danswer-ai/bugfix/hotfix-workflow-3
more hotfix workflow testing
2024-10-09 09:41:03 -07:00
Richard Kuo (Danswer)
95df136104 another cut 2024-10-09 09:40:27 -07:00
rkuo-danswer
6b57e68226 Merge pull request #2735 from danswer-ai/feature/hotfix-workflow-2
Feature/hotfix workflow 2
2024-10-08 20:38:48 -07:00
Richard Kuo (Danswer)
cbd4481838 rename 2024-10-08 20:32:39 -07:00
Richard Kuo (Danswer)
80343d6d75 update hotfix to use commas 2024-10-08 20:31:17 -07:00
pablodanswer
d5b9a6e552 add vespa + embedding timeout env variables (#2689)
* add vespa + embedding timeout env variables

* nit: integration test

* add dangerous override

* k

* add additional clarity

* nit

* nit
2024-10-09 03:20:28 +00:00
pablodanswer
10f221cd37 Remove mildly annoying groups fetch (#2733)
* remove mildly annoying groups fetch

* ensure in client component
2024-10-09 03:13:19 +00:00
pablodanswer
f83e6806b6 More robust edge detection (#2710)
* more robust edge detection

* nit

* k
2024-10-09 01:07:51 +00:00
pablodanswer
8f61505437 Fix azure (#2665)
* fix azure

* nit

* nit

* nit

* nit pretty
2024-10-08 23:13:45 +00:00
rkuo-danswer
a47d27de6c experimental workflow to auto merge hotfixes to release branches. (#2723) 2024-10-08 21:42:59 +00:00
rkuo-danswer
aa187c86e2 Merge pull request #2726 from danswer-ai/bugfix/docker-web-runners
try porting docker web build to runs-on
2024-10-08 14:42:43 -07:00
Richard Kuo (Danswer)
c72c5619f0 remove more flaky tests 2024-10-08 14:42:04 -07:00
Chris Weaver
78e7710f17 Handle bug with initial connector page display (#2727)
* Handle bug with initial connector page display

* Casing consistency
2024-10-08 21:01:37 +00:00
rkuo-danswer
672f5cc5ce urlencode the password part properly before putting it in the broker url (#2719)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-08 20:46:11 +00:00
rkuo-danswer
7b3c433ff8 Merge pull request #2717 from danswer-ai/bugfix/docker-legacy-key-value-format
Fix all LegacyKeyValueFormat docker warnings
2024-10-08 13:57:10 -07:00
Richard Kuo (Danswer)
057321a59f disable flaky test 2024-10-08 13:40:35 -07:00
Richard Kuo (Danswer)
5cc46341f7 try porting docker web build to runs-on 2024-10-08 13:11:59 -07:00
Chris Weaver
21a3921790 Better support for image generation capable models (#2725) 2024-10-08 12:41:14 -07:00
Richard Kuo (Danswer)
3586f9b565 experimental workflow to auto merge hotfixes to release branches. 2024-10-08 11:23:10 -07:00
Chris Weaver
aa69fe762b Temp patch to remove multiple tool calls (#2720) 2024-10-08 18:08:45 +00:00
pablodanswer
3ef72b8d1a k (#2721) 2024-10-08 09:33:29 -07:00
pablodanswer
a0124e4e50 ensure all timeout -> hook (#2718) 2024-10-08 15:48:38 +00:00
Richard Kuo (Danswer)
a52485bda2 Fix all LegacyKeyValueFormat docker warnings 2024-10-07 15:22:28 -07:00
rkuo-danswer
79d37156c6 better logging for actions being taken inside document_by_cc_pair_cleanup (#2713) 2024-10-07 22:21:16 +00:00
rkuo-danswer
6fa8fabb47 add one more retry and wait a little longer to allow ourselves to recover from infra issues (#2714) 2024-10-07 22:17:49 +00:00
pablodanswer
4214a3a6e2 Inline code + effect clarity (#2715)
* cleaner code blocks + form context

* cleaner

* nit
2024-10-07 15:23:37 -07:00
rkuo-danswer
1a3469d2c5 check before using fetch_versioned_implementation because it logs warnings that confuse users. (#2708)
Renamed get_is_ee_version to is_ee_version to be less redundant
2024-10-07 21:37:56 +00:00
rkuo-danswer
30dc408028 rely on stdout redirection for supervisord logging (#2711) 2024-10-07 21:30:03 +00:00
Yuhong Sun
5d356cc971 Remove Perm Sync Script Dev (#2712) 2024-10-07 13:50:30 -07:00
pablodanswer
e4c7cfde42 Minor update to initial modal (#2571)
* minor update

* nit: pretty
2024-10-07 20:29:04 +00:00
pablodanswer
1900a390d8 Linting (#2704)
* effect cleanup

* remove unused imports

* remove unne

* remove unnecessary packages

* k

* temp

* minor
2024-10-07 20:21:07 +00:00
pablodanswer
150dcc2883 back button + popups (#2707)
* back button + popups

* remove logs
2024-10-07 20:10:58 +00:00
377 changed files with 13257 additions and 4501 deletions

View File

@@ -32,16 +32,20 @@ inputs:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before retry in seconds'
description: 'Time to wait before attempt 2 in seconds'
required: false
default: '5'
default: '60'
retry-wait-time-2:
description: 'Time to wait before attempt 3 in seconds'
required: false
default: '120'
runs:
using: "composite"
steps:
- name: Build and push Docker image (First Attempt)
- name: Build and push Docker image (Attempt 1 of 3)
id: buildx1
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
continue-on-error: true
with:
context: ${{ inputs.context }}
@@ -54,16 +58,17 @@ runs:
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait to retry
- name: Wait before attempt 2
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Retry Attempt)
- name: Build and push Docker image (Attempt 2 of 3)
id: buildx2
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
@@ -74,3 +79,31 @@ runs:
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait before attempt 3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
run: |
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
sleep ${{ inputs.retry-wait-time-2 }}
shell: bash
- name: Build and push Docker image (Attempt 3 of 3)
id: buildx3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Report failure
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
run: |
echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
shell: bash

View File

@@ -11,8 +11,11 @@ env:
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
- run-id=${{ github.run_id }}
- tag=platform-${{ matrix.platform }}
strategy:
fail-fast: false
matrix:

View File

@@ -0,0 +1,172 @@
# This workflow is intended to be manually triggered via the GitHub Action tab.
# Given a hotfix branch, it will attempt to open a PR to all release branches and
# by default auto merge them
name: Hotfix release branches
on:
workflow_dispatch:
inputs:
hotfix_commit:
description: 'Hotfix commit hash'
required: true
hotfix_suffix:
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
required: true
release_branch_pattern:
description: 'Release branch pattern (regex)'
required: true
default: 'release/.*'
auto_merge:
description: 'Automatically merge the hotfix PRs'
required: true
type: choice
default: 'true'
options:
- true
- false
jobs:
hotfix_release_branches:
permissions: write-all
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# needs RKUO_DEPLOY_KEY for write access to merge PR's
- name: Checkout Repository
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
fetch-depth: 0
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Fetch All Branches
run: |
git fetch --all --prune
- name: Verify Hotfix Commit Exists
run: |
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
- name: Get Release Branches
id: get_release_branches
run: |
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
if [ -z "$BRANCHES" ]; then
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
exit 1
fi
echo "Found release branches:"
echo "$BRANCHES"
# Join the branches into a single line separated by commas
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
# Set the branches as an output
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
# notes on all the vagaries of wiring up automated PR's
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
- name: Create and Merge Pull Requests to Matching Release Branches
env:
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
run: |
# Get the branches from the previous step
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
# Convert BRANCHES to an array
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
# Loop through each release branch and create and merge a PR
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
echo "Processing $RELEASE_BRANCH..."
# Parse out the release version by removing "release/" from the branch name
RELEASE_VERSION=${RELEASE_BRANCH#release/}
echo "Release version parsed: $RELEASE_VERSION"
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
# Checkout the release branch
echo "Checking out $RELEASE_BRANCH"
git checkout "$RELEASE_BRANCH"
# Create the new hotfix branch
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
else
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
git checkout -b "$HOTFIX_BRANCH"
fi
# Check if the hotfix commit is a merge commit
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
# -m 1 uses the target branch as the base (which is what we want)
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
else
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
fi
# Perform the cherry-pick
echo "Executing: $CHERRY_PICK_CMD"
eval "$CHERRY_PICK_CMD"
if [ $? -ne 0 ]; then
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
git cherry-pick --abort
continue
fi
# Push the hotfix branch to the remote
echo "Pushing $HOTFIX_BRANCH..."
git push origin "$HOTFIX_BRANCH"
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
# Check if PR already exists
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
if [ -n "$EXISTING_PR" ]; then
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
continue
fi
# Create a new PR and capture the output
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
# Extract the URL from the output
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
echo "Pull request created: $PR_URL"
# Extract PR number from URL
PR_NUMBER=$(basename "$PR_URL")
echo "Pull request created: $PR_NUMBER"
if [ "$AUTO_MERGE" == "true" ]; then
echo "Attempting to merge pull request #$PR_NUMBER"
# Attempt to merge the PR
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
if [ $? -eq 0 ]; then
echo "Pull request #$PR_NUMBER merged successfully."
else
# Optionally, handle the error or continue
echo "Failed to merge pull request #$PR_NUMBER."
fi
fi
done

View File

@@ -167,7 +167,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -14,10 +14,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'

View File

@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"

View File

@@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'

View File

@@ -18,6 +18,6 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}

View File

@@ -74,7 +74,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Noteable Benefits of Danswer
## Other Notable Benefits of Danswer
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.

View File

@@ -8,10 +8,12 @@ Edition features outside of personal development or testing purposes. Please rea
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
@@ -36,6 +38,17 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
@@ -92,6 +105,7 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
@@ -101,7 +115,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
# Default command which does nothing
# This container is used by api server and background which specify their own CMD

View File

@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For mo
visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
@@ -55,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -1,6 +1,6 @@
# A generic, single database configuration.
[alembic]
[DEFAULT]
# path to migration scripts
script_location = alembic
@@ -47,7 +47,8 @@ prepend_sys_path = .
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
@@ -106,3 +107,12 @@ formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
[alembic]
script_location = alembic
version_locations = %(script_location)s/versions
[schema_private]
script_location = alembic_tenants
version_locations = %(script_location)s/versions

View File

@@ -1,104 +1,99 @@
from sqlalchemy.engine.base import Connection
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
from danswer.background.celery.celery_app import get_all_tenant_ids
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
# Add your model's MetaData object here for 'autogenerate' support
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def get_schema_options() -> tuple[str, bool]:
# Set up logging
logger = logging.getLogger(__name__)
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if MULTI_TENANT and schema_name == "public":
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
url = build_connection_string()
schema, _ = get_schema_options()
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
compare_server_default=True,
script_location=config.get_main_option("script_location"),
)
with context.begin_transaction():
@@ -106,20 +101,98 @@ def do_run_migrations(connection: Connection) -> None:
async def run_async_migrations() -> None:
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
await connectable.dispose()
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
else:
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())

View File

@@ -0,0 +1,26 @@
"""add additional data to notifications
Revision ID: 1b10e1fda030
Revises: 6756efa39ada
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "6756efa39ada"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("notification", "additional_data")

View File

@@ -0,0 +1,30 @@
"""add api_version and deployment_name to search settings
Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
)
op.add_column(
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "deployment_name")
op.drop_column("embedding_provider", "api_version")

View File

@@ -0,0 +1,153 @@
"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa
revision = "6756efa39ada"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None
"""
Migrate chat_session and chat_message tables to use UUID primary keys.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
3. Updates foreign key relationships
4. Removes old integer ID columns
Note: Downgrade will assign new integer IDs, not restore original ones.
"""
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
op.add_column(
"chat_session",
sa.Column(
"new_id",
sa.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
)
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
op.add_column(
"chat_message",
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET new_chat_session_id = cs.new_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "new_id", new_column_name="id")
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.add_column(
"chat_session",
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
)
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
op.execute(
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
)
op.execute(
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
)
op.alter_column("chat_session", "old_id", nullable=False)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
op.add_column(
"chat_message",
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET old_chat_session_id = cs.old_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["old_id"],
ondelete="CASCADE",
)
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "old_id", new_column_name="id")
op.alter_column(
"chat_session",
"id",
type_=sa.Integer(),
existing_type=sa.Integer(),
existing_nullable=False,
existing_server_default=False,
)
# Rename the sequence
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
# Update the default value to use the renamed sequence
op.alter_column(
"chat_session",
"id",
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
)

View File

@@ -20,7 +20,7 @@ depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
@@ -37,7 +37,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
@@ -46,7 +46,7 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
@@ -59,7 +59,7 @@ def downgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -0,0 +1,26 @@
"""add_deployment_name_to_llmprovider
Revision ID: e4334d5b33ba
Revises: ac5eaac849f9
Create Date: 2024-10-04 09:52:34.896867
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e4334d5b33ba"
down_revision = "ac5eaac849f9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("llm_provider", "deployment_name")

View File

@@ -0,0 +1,3 @@
These files are for public table migrations when operating with multi tenancy.
If you are not a Danswer developer, you can ignore this directory entirely.

View File

@@ -0,0 +1,111 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import PublicBase
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [PublicBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,24 @@
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "14a83a331951"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"user_tenant_mapping",
sa.Column("email", sa.String(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=False),
sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
sa.UniqueConstraint("email", name="uq_email"),
schema="public",
)
def downgrade() -> None:
op.drop_table("user_tenant_mapping", schema="public")

View File

@@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):

View File

@@ -5,18 +5,23 @@ from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
@@ -26,11 +31,25 @@ from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
from fastapi_users.jwt import SecretType
from fastapi_users.manager import UserManagerDependency
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.router.common import ErrorCode
from fastapi_users.router.common import ErrorModel
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
@@ -38,11 +57,11 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
@@ -60,15 +79,21 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -118,7 +143,10 @@ def verify_email_is_invited(email: str) -> None:
if not email:
raise PermissionError("Email must be specified")
email_info = validate_email(email) # can raise EmailNotValidError
try:
email_info = validate_email(email)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -136,8 +164,8 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -157,6 +185,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -191,35 +233,83 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
async def on_after_login(
self,
user: User,
request: Request | None = None,
response: Response | None = None,
) -> None:
if response is None or not MULTI_TENANT:
return
tenant_id = get_tenant_id_for_email(user.email)
tenant_token = jwt.encode(
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
)
response.set_cookie(
key="tenant_details",
value=tenant_token,
httponly=True,
secure=WEB_DOMAIN.startswith("https"),
samesite="lax",
)
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
@@ -234,45 +324,111 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
account_email=account_email,
expires_at=expires_at,
refresh_token=refresh_token,
request=request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
)
user.is_verified = is_verified_by_default
user.has_web_login = True
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
return user
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,
"access_token": access_token,
"account_id": account_id,
"account_email": account_email,
"expires_at": expires_at,
"refresh_token": refresh_token,
}
try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
}
user = await self.user_db.create(user_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(
user, update_dict={"oidc_expiry": oidc_expiry}
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
if token:
current_tenant_id.reset(token)
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
@@ -303,28 +459,50 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
try:
user = await self.get_by_email(credentials.username)
except exceptions.UserNotExists:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
return None
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
# Create a tenant-specific session
async with get_async_session_with_tenant(tenant_id) as tenant_session:
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
tenant_session, User
)
self.user_db = tenant_user_db
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
# Proceed with authentication
try:
user = await self.get_by_email(email)
if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
return user
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(
user, {"hashed_password": updated_password_hash}
)
return user
async def get_user_manager(
@@ -339,20 +517,26 @@ cookie_transport = CookieTransport(
)
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
strategy = DatabaseStrategy(
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",
name="jwt" if MULTI_TENANT else "database",
transport=cookie_transport,
get_strategy=get_database_strategy,
)
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
) # type: ignore
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@@ -366,9 +550,11 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
This way the login router does not need to be included
"""
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
logout_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
@@ -415,8 +601,8 @@ async def optional_user_(
async def optional_user(
request: Request,
user: User | None = Depends(optional_fastapi_current_user),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
@@ -509,26 +695,184 @@ def get_default_admin_user_emails_() -> list[str]:
return []
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
class OAuth2AuthorizeResponse(BaseModel):
authorization_url: str
def generate_state_token(
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
return get_oauth_router(
oauth_client,
backend,
get_user_manager,
state_secret,
redirect_url,
associate_by_email,
is_verified_by_default,
)
def get_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UP, models.ID],
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"
if redirect_url is not None:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
redirect_url=redirect_url,
)
else:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
route_name=callback_route_name,
)
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
"/callback",
name=callback_route_name,
description="The response varies based on the authentication backend used.",
responses={
status.HTTP_400_BAD_REQUEST: {
"model": ErrorModel,
"content": {
"application/json": {
"examples": {
"INVALID_STATE_TOKEN": {
"summary": "Invalid state token.",
"value": None,
},
ErrorCode.LOGIN_BAD_CREDENTIALS: {
"summary": "User is inactive.",
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
},
}
}
},
},
},
)
async def callback(
request: Request,
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
) -> RedirectResponse:
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
)
if account_email is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body
if hasattr(response, "status_code"):
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -1,9 +1,10 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
import redis
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Celery
from celery import current_task
@@ -11,49 +12,86 @@ from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
@signals.task_postrun.connect
def celery_task_postrun(
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
@@ -65,12 +103,24 @@ def celery_task_postrun(
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# logger.debug(f"Result: {retval}")
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
if state not in READY_STATES:
return
@@ -78,7 +128,7 @@ def celery_task_postrun(
if not task_id:
return
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -87,32 +137,38 @@ def celery_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(cc_pair_id)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(cc_pair_id)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
@@ -121,8 +177,12 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
# (set at the command line)'
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@@ -130,131 +190,171 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
r = get_redis_client()
if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
tenant_ids = get_all_tenant_ids()
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
time_start = time.monotonic()
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time.monotonic()
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
while True:
# Log all the locks in Redis
all_locks = r.keys("*")
logger.notice(f"Current Redis locks: {all_locks}")
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
logger.error(msg)
raise WorkerShutdown(msg)
if time_elapsed > WAIT_LIMIT:
msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
time.sleep(WAIT_INTERVAL)
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as the primary celery worker.")
time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock
sender.primary_worker_locks[tenant_id] = lock
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
@@ -267,14 +367,15 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
for tenant_id, lock in sender.primary_worker_locks.items():
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
if lock.owned():
lock.release()
sender.primary_worker_locks = {}
class CeleryTaskPlainFormatter(PlainFormatter):
@@ -304,7 +405,7 @@ def on_setup_logging(
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats celery's worker logger
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
@@ -349,17 +450,18 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
# Requires the Hub component
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
@@ -378,42 +480,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None:
try:
if not worker.primary_worker_lock:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
r = get_redis_client()
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock: redis.lock.Lock = worker.primary_worker_lock
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
@@ -427,6 +545,7 @@ celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
@@ -437,48 +556,64 @@ celery_app.autodiscover_tasks(
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-vespa-sync": {
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-connector-deletion-task": {
"task": "check_for_connector_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"check-for-prune": {
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"kombu-message-cleanup": {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"monitor-vespa-sync": {
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -21,6 +21,7 @@ from danswer.db.document import (
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
@@ -28,8 +29,8 @@ class RedisObjectHelper(ABC):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
@@ -46,7 +47,7 @@ class RedisObjectHelper(ABC):
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
@@ -60,15 +61,11 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
@@ -92,11 +89,7 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
object_id = parts[1]
return object_id
@abstractmethod
@@ -106,6 +99,7 @@ class RedisObjectHelper(ABC):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
@@ -115,17 +109,21 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id, current_only=False)
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@@ -145,7 +143,7 @@ class RedisDocumentSet(RedisObjectHelper):
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
@@ -161,17 +159,24 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
@@ -180,7 +185,7 @@ class RedisUserGroup(RedisObjectHelper):
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@@ -200,7 +205,7 @@ class RedisUserGroup(RedisObjectHelper):
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
@@ -212,13 +217,19 @@ class RedisUserGroup(RedisObjectHelper):
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class differs from the default in that the taskset used spans
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -240,11 +251,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@@ -274,7 +286,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
@@ -290,17 +302,21 @@ class RedisConnectorDeletion(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@@ -332,6 +348,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
@@ -376,9 +393,7 @@ class RedisConnectorPruning(RedisObjectHelper):
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
@@ -405,11 +420,12 @@ class RedisConnectorPruning(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@@ -438,6 +454,7 @@ class RedisConnectorPruning(RedisObjectHelper):
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
@@ -451,7 +468,7 @@ class RedisConnectorPruning(RedisObjectHelper):
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=self._id, db_session=db_session
cc_pair_id=int(self._id), db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
@@ -462,6 +479,66 @@ class RedisConnectorPruning(RedisObjectHelper):
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists

View File

@@ -3,10 +3,13 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -16,6 +19,7 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
@@ -27,7 +31,10 @@ logger = setup_logger()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -40,7 +47,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
@@ -50,9 +57,14 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None
@@ -124,10 +136,30 @@ def celery_is_worker_primary(worker: Any) -> bool:
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("primary"):
return True
if hostname.startswith("heavy"):
return False
return False
return True
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants

View File

@@ -1,4 +1,6 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
import urllib.parse
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
@@ -17,7 +19,7 @@ CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@"
REDIS_SCHEME = "redis"
@@ -39,6 +41,11 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT

View File

@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
r = get_redis_client()
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -36,11 +36,11 @@ def check_for_connector_deletion_task() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
cc_pair, db_session, r, lock_beat, tenant_id
)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -58,6 +58,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -90,7 +91,9 @@ def try_generate_document_cc_pair_cleanup_tasks(
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rcd.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None

View File

@@ -0,0 +1,455 @@
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
)
def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair.id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings_instance,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
return tasks_created
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
def try_creating_indexing_task(
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# create the index attempt ... just for tracking purposes
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
return None
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job:
return
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
break
return
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
"""
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
return n_final_progress

View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
from danswer.db.engine import get_session_with_tenant
@shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine # type: ignore
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -3,7 +3,6 @@ from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
@@ -15,7 +14,9 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
@@ -24,18 +25,21 @@ from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name="check_for_prune_task_2",
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task_2() -> None:
r = get_redis_client()
def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -47,16 +51,20 @@ def check_for_prune_task_2() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, r, lock_beat
lock_beat.reacquire()
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -68,12 +76,11 @@ def check_for_prune_task_2() -> None:
lock_beat.release()
def ccpair_pruning_generator_task_creation_helper(
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
@@ -84,84 +91,137 @@ def ccpair_pruning_generator_task_creation_helper(
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return None
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return None
return False
return try_creating_prune_generator_task(cc_pair, db_session, r)
return True
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately.
is used to trigger prunes immediately, e.g. via the web ui.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
rcp = RedisConnectorPruning(cc_pair.id)
LOCK_TIMEOUT = 30
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
@shared_task(
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
)
def connector_pruning_generator_task(
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
with Session(get_sqlalchemy_engine()) as db_session:
try:
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -170,14 +230,13 @@ def connector_pruning_generator_task(connector_id: int, credential_id: int) -> N
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
f"cc_pair not found for {connector_id} {credential_id}"
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
@@ -218,7 +277,9 @@ def connector_pruning_generator_task(connector_id: int, credential_id: int) -> N
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
tasks_generated = rcp.generate_tasks(
celery_app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
@@ -228,12 +289,13 @@ def connector_pruning_generator_task(connector_id: int, credential_id: int) -> N
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():
lock.release()

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
@@ -11,13 +13,20 @@ from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
@@ -26,7 +35,11 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
self: Task,
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -46,7 +59,10 @@ def document_by_cc_pair_cleanup_task(
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
@@ -56,12 +72,16 @@ def document_by_cc_pair_cleanup_task(
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete(doc_ids=[document_id])
action = "delete"
chunks_affected = document_index.delete_single(document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
action = "update"
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
@@ -84,7 +104,9 @@ def document_by_cc_pair_cleanup_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(document_id, fields=fields)
chunks_affected = document_index.update_single(
document_id, fields=fields
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
@@ -100,9 +122,18 @@ def document_by_cc_pair_cleanup_task(
else:
pass
task_logger.info(
f"tenant_id={tenant_id} "
f"document_id={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
except Exception as e:
task_logger.exception("Unexpected exception")

View File

@@ -1,23 +1,32 @@
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_pruned
@@ -29,6 +38,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
@@ -36,20 +46,29 @@ from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@@ -58,11 +77,11 @@ from danswer.utils.variable_functionality import noop_fallback
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -74,8 +93,8 @@ def check_for_vespa_sync_task() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
@@ -83,25 +102,28 @@ def check_for_vespa_sync_task() -> None:
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
document_set, db_session, r, lock_beat, tenant_id
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
pass
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat, tenant_id
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -114,7 +136,7 @@ def check_for_vespa_sync_task() -> None:
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -139,7 +161,9 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
continue
@@ -163,7 +187,11 @@ def try_generate_stale_document_sync_tasks(
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
document_set: DocumentSet,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -187,7 +215,9 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
@@ -208,7 +238,11 @@ def try_generate_document_set_sync_tasks(
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
usergroup: UserGroup,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -230,7 +264,9 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
@@ -275,11 +311,13 @@ def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id_str is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
@@ -294,7 +332,8 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
f"Document set sync progress: document_set_id={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
return
@@ -320,13 +359,17 @@ def monitor_document_set_taskset(
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
@@ -341,25 +384,36 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# gating off pruning and indexing work before deletion starts
task_logger.warning(
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
cc_pair_id=cc_pair_id,
)
# document sets
@@ -376,7 +430,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
cc_pair_id=cc_pair_id,
db_session=db_session,
)
@@ -398,20 +452,21 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
db_session.delete(connector)
db_session.commit()
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
@@ -423,13 +478,15 @@ def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
@@ -453,7 +510,7 @@ def monitor_ccpair_pruning_taskset(
if count > 0:
return
mark_ccpair_as_pruned(cc_pair_id, db_session)
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
@@ -464,18 +521,131 @@ def monitor_ccpair_pruning_taskset(
r.delete(rcp.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
def monitor_ccpair_indexing_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
# Read result state BEFORE generator_complete_key to avoid a race condition
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = get_redis_client()
lock_beat = r.lock(
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -483,18 +653,46 @@ def monitor_vespa_sync() -> None:
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
return False
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
n_deletion = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
@@ -505,9 +703,33 @@ def monitor_vespa_sync() -> None:
)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -520,6 +742,8 @@ def monitor_vespa_sync() -> None:
if lock_beat.owned():
lock_beat.release()
return True
@shared_task(
name="vespa_metadata_sync_task",
@@ -528,11 +752,13 @@ def monitor_vespa_sync() -> None:
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
@@ -550,20 +776,24 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = document_index.update_single(document_id, fields=fields)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
task_logger.info(
f"document_id={document_id} action=sync chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:

View File

@@ -1,5 +1,6 @@
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -17,13 +18,12 @@ from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import transition_attempt_to_in_progress
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
@@ -46,6 +46,7 @@ def _get_connector_runner(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -63,6 +64,7 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -89,11 +91,16 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
@@ -129,6 +136,7 @@ def _run_indexing(
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -185,6 +193,7 @@ def _run_indexing(
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
@@ -197,7 +206,7 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_connector)
db_session.refresh(db_cc_pair)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
@@ -212,7 +221,9 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
)
batch_description = []
for doc in doc_batch:
@@ -232,6 +243,8 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
@@ -250,6 +263,9 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
@@ -373,40 +389,13 @@ def _run_indexing(
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress(attempt, db_session)
return attempt
def run_indexing_entrypoint(
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
index_attempt_id: int,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
@@ -416,26 +405,29 @@ def run_indexing_entrypoint(
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
logger.info(
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
logger.info(
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)

View File

@@ -1,495 +1,494 @@
import logging
import time
from datetime import datetime
import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# If the indexing dies, it's most likely due to resource constraints,
# restarting just delays the eventual failure, not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})
_UNEXPECTED_STATE_FAILURE_REASON = (
"Stopped mid run, likely due to the background process being killed"
)
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
if not last_index:
return True
if connector.refresh_freq is None:
return False
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
return time_since_index.total_seconds() >= connector.refresh_freq
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason=failure_reason,
)
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
if attempt is None:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
"indexing jobs"
)
continue
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.search_settings_id,
)
)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for search_settings_instance in search_settings:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, search_settings_instance.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
create_index_attempt(
cc_pair.id, search_settings_instance.id, db_session
)
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
del existing_jobs_copy[attempt_id]
if not index_attempt:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
"up indexing jobs"
)
continue
if (
index_attempt.status == IndexingStatus.IN_PROGRESS
or job.status == "error"
):
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
return existing_jobs_copy
def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine()
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.search_settings)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
primary_client_full = False
secondary_client_full = False
for attempt, search_settings in new_indexing_attempts:
if primary_client_full and secondary_client_full:
break
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
else False
)
if attempt.connector_credential_pair.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.connector_credential_pair.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
continue
if not use_secondary_index:
if not primary_client_full:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
primary_client_full = True
else:
if not secondary_client_full:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
secondary_client_full = True
if run:
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
logger.info(
f"Indexing dispatched{secondary_str}: "
f"attempt_id={attempt.id} "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warning about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
client_primary = Client(cluster_primary)
client_secondary = Client(cluster_secondary)
if LOG_LEVEL.lower() == "debug":
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
try:
with Session(get_sqlalchemy_engine()) as db_session:
check_index_swap(db_session)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
client=client_primary,
secondary_client=client_secondary,
)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
def update__main() -> None:
set_is_ee_based_on_env_variable()
# initialize the Postgres connection pool
SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
logger.notice("Starting indexing service")
update_loop()
if __name__ == "__main__":
update__main()
# TODO(rkuo): delete after background indexing via celery is fully vetted
# import logging
# import time
# from datetime import datetime
# import dask
# from dask.distributed import Client
# from dask.distributed import Future
# from distributed import LocalCluster
# from sqlalchemy import text
# from sqlalchemy.exc import ProgrammingError
# from sqlalchemy.orm import Session
# from danswer.background.indexing.dask_utils import ResourceLogger
# from danswer.background.indexing.job_client import SimpleJob
# from danswer.background.indexing.job_client import SimpleJobClient
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
# from danswer.configs.app_configs import MULTI_TENANT
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
# from danswer.configs.constants import DocumentSource
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
# from danswer.configs.constants import TENANT_ID_PREFIX
# from danswer.db.connector import fetch_connectors
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
# from danswer.db.engine import get_db_current_time
# from danswer.db.engine import get_session_with_tenant
# from danswer.db.engine import get_sqlalchemy_engine
# from danswer.db.engine import SqlEngine
# from danswer.db.index_attempt import create_index_attempt
# from danswer.db.index_attempt import get_index_attempt
# from danswer.db.index_attempt import get_inprogress_index_attempts
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
# from danswer.db.index_attempt import get_not_started_index_attempts
# from danswer.db.index_attempt import mark_attempt_failed
# from danswer.db.models import ConnectorCredentialPair
# from danswer.db.models import IndexAttempt
# from danswer.db.models import IndexingStatus
# from danswer.db.models import IndexModelStatus
# from danswer.db.models import SearchSettings
# from danswer.db.search_settings import get_current_search_settings
# from danswer.db.search_settings import get_secondary_search_settings
# from danswer.db.swap_index import check_index_swap
# from danswer.document_index.vespa.index import VespaIndex
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
# from danswer.utils.logger import setup_logger
# from danswer.utils.variable_functionality import global_version
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
# from shared_configs.configs import LOG_LEVEL
# logger = setup_logger()
# # If the indexing dies, it's most likely due to resource constraints,
# # restarting just delays the eventual failure, not useful to the user
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
# _UNEXPECTED_STATE_FAILURE_REASON = (
# "Stopped mid run, likely due to the background process being killed"
# )
# def _should_create_new_indexing(
# cc_pair: ConnectorCredentialPair,
# last_index: IndexAttempt | None,
# search_settings_instance: SearchSettings,
# secondary_index_building: bool,
# db_session: Session,
# ) -> bool:
# connector = cc_pair.connector
# # don't kick off indexing for `NOT_APPLICABLE` sources
# if connector.source == DocumentSource.NOT_APPLICABLE:
# return False
# # User can still manually create single indexing attempts via the UI for the
# # currently in use index
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# if (
# search_settings_instance.status == IndexModelStatus.PRESENT
# and secondary_index_building
# ):
# return False
# # When switching over models, always index at least once
# if search_settings_instance.status == IndexModelStatus.FUTURE:
# if last_index:
# # No new index if the last index attempt succeeded
# # Once is enough. The model will never be able to swap otherwise.
# if last_index.status == IndexingStatus.SUCCESS:
# return False
# # No new index if the last index attempt is waiting to start
# if last_index.status == IndexingStatus.NOT_STARTED:
# return False
# # No new index if the last index attempt is running
# if last_index.status == IndexingStatus.IN_PROGRESS:
# return False
# else:
# if (
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
# ): # Ingestion API
# return False
# return True
# # If the connector is paused or is the ingestion API, don't index
# # NOTE: during an embedding model switch over, the following logic
# # is bypassed by the above check for a future model
# if (
# not cc_pair.status.is_active()
# or connector.id == 0
# or connector.source == DocumentSource.INGESTION_API
# ):
# return False
# if not last_index:
# return True
# if connector.refresh_freq is None:
# return False
# # Only one scheduled/ongoing job per connector at a time
# # this prevents cases where
# # (1) the "latest" index_attempt is scheduled so we show
# # that in the UI despite another index_attempt being in-progress
# # (2) multiple scheduled index_attempts at a time
# if (
# last_index.status == IndexingStatus.NOT_STARTED
# or last_index.status == IndexingStatus.IN_PROGRESS
# ):
# return False
# current_db_time = get_db_current_time(db_session)
# time_since_index = current_db_time - last_index.time_updated
# return time_since_index.total_seconds() >= connector.refresh_freq
# def _mark_run_failed(
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
# ) -> None:
# """Marks the `index_attempt` row as failed + updates the `
# connector_credential_pair` to reflect that the run failed"""
# logger.warning(
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
# )
# mark_attempt_failed(
# index_attempt=index_attempt,
# db_session=db_session,
# failure_reason=failure_reason,
# )
# """Main funcs"""
# def create_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
# ) -> None:
# """Creates new indexing jobs for each connector / credential pair which is:
# 1. Enabled
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
# 3. There is not already an ongoing indexing attempt for this pair
# """
# with get_session_with_tenant(tenant_id) as db_session:
# ongoing: set[tuple[int | None, int]] = set()
# for attempt_id in existing_jobs:
# attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# if attempt is None:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
# "indexing jobs"
# )
# continue
# ongoing.add(
# (
# attempt.connector_credential_pair_id,
# attempt.search_settings_id,
# )
# )
# # Get the primary search settings
# primary_search_settings = get_current_search_settings(db_session)
# search_settings = [primary_search_settings]
# # Check for secondary search settings
# secondary_search_settings = get_secondary_search_settings(db_session)
# if secondary_search_settings is not None:
# # If secondary settings exist, add them to the list
# search_settings.append(secondary_search_settings)
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
# for cc_pair in all_connector_credential_pairs:
# for search_settings_instance in search_settings:
# # Check if there is an ongoing indexing attempt for this connector credential pair
# if (cc_pair.id, search_settings_instance.id) in ongoing:
# continue
# last_attempt = get_last_attempt_for_cc_pair(
# cc_pair.id, search_settings_instance.id, db_session
# )
# if not _should_create_new_indexing(
# cc_pair=cc_pair,
# last_index=last_attempt,
# search_settings_instance=search_settings_instance,
# secondary_index_building=len(search_settings) > 1,
# db_session=db_session,
# ):
# continue
# create_index_attempt(
# cc_pair.id, search_settings_instance.id, db_session
# )
# def cleanup_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# tenant_id: str | None,
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# # clean up completed jobs
# with get_session_with_tenant(tenant_id) as db_session:
# for attempt_id, job in existing_jobs.items():
# index_attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# # do nothing for ongoing jobs that haven't been stopped
# if not job.done():
# if not index_attempt:
# continue
# if not index_attempt.is_finished():
# continue
# if job.status == "error":
# logger.error(job.exception())
# job.release()
# del existing_jobs_copy[attempt_id]
# if not index_attempt:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
# "up indexing jobs"
# )
# continue
# if (
# index_attempt.status == IndexingStatus.IN_PROGRESS
# or job.status == "error"
# ):
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# # clean up in-progress jobs that were never completed
# try:
# connectors = fetch_connectors(db_session)
# for connector in connectors:
# in_progress_indexing_attempts = get_inprogress_index_attempts(
# connector.id, db_session
# )
# for index_attempt in in_progress_indexing_attempts:
# if index_attempt.id in existing_jobs:
# # If index attempt is canceled, stop the run
# if index_attempt.status == IndexingStatus.FAILED:
# existing_jobs[index_attempt.id].cancel()
# # check to see if the job has been updated in last `timeout_hours` hours, if not
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
# # on the fact that the `time_updated` field is constantly updated every
# # batch of documents indexed
# current_db_time = get_db_current_time(db_session=db_session)
# time_since_update = current_db_time - index_attempt.time_updated
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
# existing_jobs[index_attempt.id].cancel()
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason="Indexing run frozen - no updates in the last three hours. "
# "The run will be re-attempted at next scheduled indexing time.",
# )
# else:
# # If job isn't known, simply mark it as failed
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# except ProgrammingError:
# logger.debug(f"No Connector Table exists for: {tenant_id}")
# return existing_jobs_copy
# def kickoff_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# client: Client | SimpleJobClient,
# secondary_client: Client | SimpleJobClient,
# tenant_id: str | None,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# current_session = get_session_with_tenant(tenant_id)
# # Don't include jobs waiting in the Dask queue that just haven't started running
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
# with current_session as db_session:
# # get_not_started_index_attempts orders its returned results from oldest to newest
# # we must process attempts in a FIFO manner to prevent connector starvation
# new_indexing_attempts = [
# (attempt, attempt.search_settings)
# for attempt in get_not_started_index_attempts(db_session)
# if attempt.id not in existing_jobs
# ]
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
# if not new_indexing_attempts:
# return existing_jobs
# indexing_attempt_count = 0
# primary_client_full = False
# secondary_client_full = False
# for attempt, search_settings in new_indexing_attempts:
# if primary_client_full and secondary_client_full:
# break
# use_secondary_index = (
# search_settings.status == IndexModelStatus.FUTURE
# if search_settings is not None
# else False
# )
# if attempt.connector_credential_pair.connector is None:
# logger.warning(
# f"Skipping index attempt as Connector has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Connector is null"
# )
# continue
# if attempt.connector_credential_pair.credential is None:
# logger.warning(
# f"Skipping index attempt as Credential has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Credential is null"
# )
# continue
# if not use_secondary_index:
# if not primary_client_full:
# run = client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# primary_client_full = True
# else:
# if not secondary_client_full:
# run = secondary_client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# secondary_client_full = True
# if run:
# if indexing_attempt_count == 0:
# logger.info(
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
# )
# indexing_attempt_count += 1
# secondary_str = " (secondary index)" if use_secondary_index else ""
# logger.info(
# f"Indexing dispatched{secondary_str}: "
# f"attempt_id={attempt.id} "
# f"connector='{attempt.connector_credential_pair.connector.name}' "
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
# )
# existing_jobs_copy[attempt.id] = run
# if indexing_attempt_count > 0:
# logger.info(
# f"Indexing dispatch results: "
# f"initial_pending={len(new_indexing_attempts)} "
# f"started={indexing_attempt_count} "
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
# )
# return existing_jobs_copy
# def get_all_tenant_ids() -> list[str] | list[None]:
# if not MULTI_TENANT:
# return [None]
# with get_session_with_tenant(tenant_id="public") as session:
# result = session.execute(
# text(
# """
# SELECT schema_name
# FROM information_schema.schemata
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
# )
# )
# tenant_ids = [row[0] for row in result]
# valid_tenants = [
# tenant
# for tenant in tenant_ids
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
# ]
# return valid_tenants
# def update_loop(
# delay: int = 10,
# num_workers: int = NUM_INDEXING_WORKERS,
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
# ) -> None:
# if not MULTI_TENANT:
# # We can use this function as we are certain only the public schema exists
# # (explicitly for the non-`MULTI_TENANT` case)
# engine = get_sqlalchemy_engine()
# with Session(engine) as db_session:
# check_index_swap(db_session=db_session)
# search_settings = get_current_search_settings(db_session)
# # So that the first time users aren't surprised by really slow speed of first
# # batch of documents indexed
# if search_settings.provider_type is None:
# logger.notice("Running a first inference to warm up embedding model")
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(
# embedding_model=embedding_model,
# )
# logger.notice("First inference complete.")
# client_primary: Client | SimpleJobClient
# client_secondary: Client | SimpleJobClient
# if DASK_JOB_CLIENT_ENABLED:
# cluster_primary = LocalCluster(
# n_workers=num_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# cluster_secondary = LocalCluster(
# n_workers=num_secondary_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# client_primary = Client(cluster_primary)
# client_secondary = Client(cluster_secondary)
# if LOG_LEVEL.lower() == "debug":
# client_primary.register_worker_plugin(ResourceLogger())
# else:
# client_primary = SimpleJobClient(n_workers=num_workers)
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
# logger.notice("Startup complete. Waiting for indexing jobs...")
# while True:
# start = time.time()
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
# if existing_jobs:
# logger.debug(
# "Found existing indexing jobs: "
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
# )
# try:
# tenants = get_all_tenant_ids()
# for tenant_id in tenants:
# try:
# logger.debug(
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
# )
# with get_session_with_tenant(tenant_id) as db_session:
# index_to_expire = check_index_swap(db_session=db_session)
# if index_to_expire and tenant_id and MULTI_TENANT:
# VespaIndex.delete_entries_by_tenant_id(
# tenant_id=tenant_id,
# index_name=index_to_expire.index_name,
# )
# if not MULTI_TENANT:
# search_settings = get_current_search_settings(db_session)
# if search_settings.provider_type is None:
# logger.notice(
# "Running a first inference to warm up embedding model"
# )
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(embedding_model=embedding_model)
# logger.notice("First inference complete.")
# tenant_jobs = existing_jobs.get(tenant_id, {})
# tenant_jobs = cleanup_indexing_jobs(
# existing_jobs=tenant_jobs, tenant_id=tenant_id
# )
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
# tenant_jobs = kickoff_indexing_jobs(
# existing_jobs=tenant_jobs,
# client=client_primary,
# secondary_client=client_secondary,
# tenant_id=tenant_id,
# )
# existing_jobs[tenant_id] = tenant_jobs
# except Exception as e:
# logger.exception(
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
# )
# except Exception as e:
# logger.exception(f"Failed to run update due to {e}")
# sleep_time = delay - (time.time() - start)
# if sleep_time > 0:
# time.sleep(sleep_time)
# def update__main() -> None:
# set_is_ee_based_on_env_variable()
# # initialize the Postgres connection pool
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
# logger.notice("Starting indexing service")
# update_loop()
# if __name__ == "__main__":
# update__main()

View File

@@ -1,6 +1,8 @@
import re
from typing import cast
from uuid import UUID
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
@@ -33,7 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
@@ -166,3 +168,31 @@ def reorganize_citations(
new_citation_info[citation.citation_num] = citation
return new_answer, list(new_citation_info.values())
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:
"""
Extract headers specified in pass_through_headers from input headers.
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
Args:
headers: Input headers as dict or Headers object.
Returns:
dict: Filtered headers based on pass_through_headers.
"""
if not pass_through_headers:
return {}
extracted_headers: dict[str, str] = {}
for key in pass_through_headers:
if key in headers:
extracted_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers

View File

@@ -18,6 +18,10 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -101,6 +105,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -272,6 +277,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
@@ -560,7 +566,26 @@ def stream_chat_message_objects(
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = llm.config
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
@@ -579,7 +604,7 @@ def stream_chat_message_objects(
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
@@ -591,6 +616,7 @@ def stream_chat_message_objects(
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
@@ -615,7 +641,12 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
@@ -838,6 +869,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
@@ -847,6 +879,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:

View File

@@ -53,7 +53,6 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@@ -116,10 +115,16 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
# The default below is for dockerized deployment
VESPA_DEPLOYMENT_ZIP = (
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
@@ -135,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(
@@ -401,6 +406,11 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
)
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "5")
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
#####
# Enterprise Edition Configs
@@ -413,10 +423,38 @@ ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Azure DALL-E Configurations
AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION")
AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
# Cloud configuration
# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Security and authentication
SECRET_JWT_KEY = os.environ.get(
"SECRET_JWT_KEY", ""
) # Used for encryption of the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
) # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
"EXPECTED_API_KEY", ""
) # Additional security check for the control plane API
# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)
# JWT configuration
JWT_ALGORITHM = "HS256"

View File

@@ -31,6 +31,9 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -39,6 +42,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
@@ -70,6 +75,16 @@ KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -115,8 +130,13 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
class BlobType(str, Enum):
@@ -141,6 +161,9 @@ class AuthType(str, Enum):
OIDC = "oidc"
SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum):
CHAT = "Chat"
@@ -190,14 +213,19 @@ class DanswerCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_INDEXING = "connector_indexing"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
class DanswerCeleryPriority(int, Enum):

View File

@@ -0,0 +1,22 @@
import json
import os
# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
)
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
try:
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)

View File

@@ -6,7 +6,8 @@ from datetime import datetime
from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import urlparse
import bs4
from atlassian import Confluence # type:ignore
@@ -32,7 +33,6 @@ from danswer.connectors.confluence.rate_limit_handler import (
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
@@ -56,8 +56,40 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
)
class DanswerConfluence(Confluence):
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is necessary because the default Confluence class does not properly support cql expansions.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(DanswerConfluence, self).__init__(url, *args, **kwargs)
def danswer_cql(
self,
cql: str,
expand: str | None = None,
cursor: str | None = None,
limit: int = 500,
include_archived_spaces: bool = False,
) -> dict[str, Any]:
url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
url_suffix += f"&expand={expand}"
if cursor:
url_suffix += f"&cursor={cursor}"
url_suffix += f"&limit={limit}"
if include_archived_spaces:
url_suffix += "&includeArchivedSpaces=true"
try:
response = self.get(url_suffix)
return response
except Exception as e:
raise e
@lru_cache()
def _get_user(user_id: str, confluence_client: Confluence) -> str:
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -73,6 +105,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
confluence_client.get_user_details_by_accountid
)
try:
logger.info(f"_get_user - get_user_details_by_accountid: id={user_id}")
return get_user_details_by_accountid(user_id).get("displayName", user_not_found)
except Exception as e:
logger.warning(
@@ -81,7 +114,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
return user_not_found
def parse_html_page(text: str, confluence_client: Confluence) -> str:
def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -112,7 +145,7 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
confluence_client: Confluence,
confluence_client: DanswerConfluence,
) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
@@ -124,6 +157,9 @@ def _comment_dfs(
comment_html, confluence_client
)
try:
logger.info(
f"_comment_dfs - get_page_by_child_type: id={comment_page['id']}"
)
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
@@ -163,130 +199,103 @@ class RecursiveIndexer:
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.batch_size = batch_size
self.confluence_client = confluence_client
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
self.pages = self.recurse_children_pages(self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def get_pages(self) -> list[dict[str, Any]]:
return self.pages
def _fetch_origin_page(
self,
) -> dict[str, Any]:
def _fetch_origin_page(self) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
logger.info(
f"_fetch_origin_page - get_page_by_id: id={self.origin_page_id}"
)
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
self.origin_page_id, expand="body.storage.value,version,space"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
except Exception:
logger.exception(
f"Appending origin page with id {self.origin_page_id} failed."
)
return {}
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
queue: list[str] = [page_id]
visited_pages: set[str] = set()
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
while queue:
current_page_id = queue.pop(0)
if current_page_id in visited_pages:
continue
visited_pages.add(current_page_id)
child_pages.extend(child_page)
return child_pages
try:
# Fetch the page itself
logger.info(
f"recurse_children_pages - get_page_by_id: id={current_page_id}"
)
page = get_page_by_id(
current_page_id, expand="body.storage.value,version,space"
)
pages.append(page)
except Exception:
logger.exception(f"Failed to fetch page {current_page_id}.")
continue
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
if not self.index_recursively:
continue
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
# Fetch child pages
start = 0
while True:
logger.info(
f"recurse_children_pages - get_page_by_child_type: id={current_page_id}"
)
child_pages_response = get_page_child_by_type(
current_page_id,
type="page",
start=start,
limit=self.batch_size,
expand="",
)
if not child_pages_response:
break
for child_page in child_pages_response:
child_page_id = child_page["id"]
queue.append(child_page_id)
start += len(child_pages_response)
return child_pages
return pages
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_base: str,
space: str,
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
@@ -295,104 +304,167 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
cql_query: str | None = None,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_recursively = index_recursively
self.index_recursively = False if cql_query else index_recursively
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
self.page_id = page_id
self.page_id = "" if cql_query else page_id
self.space_level_scan = bool(not self.page_id)
self.is_cloud = is_cloud
self.space_level_scan = False
self.confluence_client: Confluence | None = None
self.confluence_client: DanswerConfluence | None = None
if self.page_id is None or self.page_id == "":
self.space_level_scan = True
# if a cql_query is provided, we will use it to fetch the pages
# if no cql_query is provided, we will use the space to fetch the pages
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
if cql_query:
self.cql_query = cql_query
elif space:
self.cql_query = f"type=page and space='{space}'"
else:
self.cql_query = "type=page"
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+ f" cql_query: {self.cql_query}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
self.confluence_client = Confluence(
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = DanswerConfluence(
url=self.wiki_base,
# passing in username causes issues for Confluence data center
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)
return None
def _fetch_pages(
self,
confluence_client: Confluence,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
cursor: str | None,
) -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
def _fetch_space(
cursor: str | None, batch_size: int
) -> tuple[list[dict[str, Any]], str | None]:
if not self.confluence_client:
raise Exception("Confluence client is not initialized")
get_all_pages = make_confluence_call_handle_rate_limit(
self.confluence_client.danswer_cql
)
include_archived_spaces = (
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
if not self.is_cloud
else False
)
try:
return get_all_pages_from_space(
self.space,
start=start_ind,
limit=batch_size,
status=(
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
),
expand="body.storage.value,version",
logger.info(
f"_fetch_space - get_all_pages: cursor={cursor} limit={batch_size}"
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=batch_size,
expand="body.storage.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
next_cursor = None
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
next_cursor = cursor_list[0]
return pages, next_cursor
except Exception:
logger.warning(
f"Batch failed with space {self.space} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
f"and size {batch_size}, processing pages individually..."
)
view_pages: list[dict[str, Any]] = []
for i in range(self.batch_size):
for _ in range(self.batch_size):
try:
# Could be that one of the pages here failed due to this bug:
# https://jira.atlassian.com/browse/CONFCLOUD-76433
view_pages.extend(
get_all_pages_from_space(
self.space,
start=start_ind + i,
limit=1,
status=(
None
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
else "current"
),
expand="body.storage.value,version",
)
logger.info(
f"_fetch_space - get_all_pages: cursor={cursor} limit=1"
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
except HTTPError as e:
logger.warning(
f"Page failed with space {self.space} at offset {start_ind + i}, "
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
f"trying alternative expand option: {e}"
)
# Use view instead, which captures most info but is less complete
view_pages.extend(
get_all_pages_from_space(
self.space,
start=start_ind + i,
limit=1,
expand="body.view.value,version",
)
logger.info(
f"_fetch_space - get_all_pages - trying alternative expand: cursor={cursor} limit=1"
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
return view_pages
return view_pages, cursor
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
@@ -401,41 +473,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
index_recursively=self.index_recursively,
)
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
pages: list[dict[str, Any]] = []
pages = self.recursive_indexer.get_pages()
return pages, None # Since we fetched all pages, no cursor
try:
pages = (
_fetch_space(start_ind, self.batch_size)
pages, next_cursor = (
_fetch_space(cursor, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
else _fetch_page()
)
return pages
return pages, next_cursor
except Exception as e:
if not self.continue_on_failure:
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
for i in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
)
return pages
logger.exception("Ran into exception when fetching pages from Confluence")
return [], None
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
@@ -443,24 +496,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
try:
comment_pages = cast(
Collection[dict[str, Any]],
logger.info(f"_fetch_comments - get_page_child_by_type: id={page_id}")
comment_pages = list(
get_page_child_by_type(
page_id,
type="comment",
start=None,
limit=None,
expand="body.storage.value",
),
)
)
return _comment_dfs("", comment_pages, confluence_client)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when fetching comments from Confluence"
)
logger.exception("Fetching comments from Confluence exceptioned")
return ""
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
@@ -468,13 +519,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
confluence_client.get_page_labels
)
try:
logger.info(f"_fetch_labels - get_page_labels: id={page_id}")
labels_response = get_page_labels(page_id)
return [label["name"] for label in labels_response["results"]]
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception("Ran into exception when fetching labels from Confluence")
logger.exception("Fetching labels from Confluence exceptioned")
return []
@classmethod
@@ -511,6 +563,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
@@ -534,22 +587,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return extracted_text
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list = []
unused_attachments: list[dict[str, Any]] = []
files_attachment_content: list[str] = []
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500, expand=expand
page_id, start=None, limit=None, expand=expand
)
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
for attachment in attachments_container.get("results", []):
if attachment["title"] not in files_in_use:
unused_attachments.append(attachment)
continue
@@ -567,36 +620,33 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"User does not have access to attachments on page '{page_id}'"
)
return "", []
if not self.continue_on_failure:
raise e
logger.exception(
f"Ran into exception when fetching attachments from Confluence: {e}"
)
logger.exception("Fetching attachments from Confluence exceptioned.")
return "\n".join(files_attachment_content), unused_attachments
def _get_doc_batch(
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
) -> tuple[list[Document], list[dict[str, Any]], int]:
doc_batch: list[Document] = []
self, cursor: str | None
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
doc_batch: list[Any] = []
unused_attachments: list[dict[str, Any]] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
batch, next_cursor = self._fetch_pages(cursor)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
author = cast(str | None, page["version"].get("by", {}).get("email"))
if time_filter and not time_filter(last_modified):
continue
author = page["version"].get("by", {}).get("email")
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
else:
page_labels = []
# check disallowed labels
if self.labels_to_skip:
@@ -606,7 +656,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -621,16 +670,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
files_in_use = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
self.confluence_client, page_id, files_in_use
)
unused_attachments.extend(unused_page_attachments)
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": page["space"]["name"]
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
@@ -649,8 +700,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return (
doc_batch,
next_cursor,
unused_attachments,
len(batch),
)
def _get_attachment_batch(
@@ -658,8 +709,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
) -> tuple[list[Any], int]:
doc_batch: list[Any] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -687,7 +738,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
creator_email = attachment["history"]["createdBy"].get("email")
comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
doc_metadata: dict[str, Any] = {"comment": comment}
attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -714,69 +765,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return doc_batch, end_ind - start_ind
def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
if doc_batch:
yield doc_batch
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
def _handle_batch_retrieval(
self,
start: float | None = None,
end: float | None = None,
) -> GenerateDocumentsOutput:
unused_attachments = []
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
start_ind = 0
unused_attachments: list[dict[str, Any]] = []
cursor = None
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind, time_filter=lambda t: start_time <= t <= end_time
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
unused_attachments.extend(new_unused_attachments)
if doc_batch:
yield doc_batch
if num_pages < self.batch_size:
if not cursor:
break
# Process attachments if any
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
start_ind=start_ind,
attachments=unused_attachments,
time_filter=(lambda t: start_time <= t <= end_time)
if start_time and end_time
else None,
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
@@ -784,6 +802,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_attachments < self.batch_size:
break
def load_from_state(self) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval(start=start, end=end)
if __name__ == "__main__":
connector = ConfluenceConnector(

View File

@@ -1,3 +1,4 @@
import math
import time
from collections.abc import Callable
from typing import Any
@@ -21,62 +22,198 @@ class ConfluenceRateLimitError(Exception):
pass
# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
# max_retries = 5
# starting_delay = 5
# backoff = 2
# # max_delay is used when the server doesn't hand back "Retry-After"
# # and we have to decide the retry delay ourselves
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
# # max_retry_after is used when we do get a "Retry-After" header
# max_retry_after = 300 # should we really cap the maximum retry delay?
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries):
# try:
# # if multiple connectors are waiting for the next attempt, there could be an issue
# # where many connectors are "released" onto the server at the same time.
# # That's not ideal ... but coming up with a mechanism for queueing
# # all of these connectors is a bigger problem that we want to take on
# # right now
# try:
# next_attempt = r.get(NEXT_RETRY_KEY)
# if next_attempt is None:
# next_attempt = 0
# else:
# next_attempt = int(cast(int, next_attempt))
# # TODO: all connectors need to be interruptible moving forward
# while time.monotonic() < next_attempt:
# time.sleep(1)
# except ConnectionError:
# pass
# return confluence_call(*args, **kwargs)
# except HTTPError as e:
# # Check if the response or headers are None to avoid potential AttributeError
# if e.response is None or e.response.headers is None:
# logger.warning("HTTPError with `None` as response or as headers")
# raise e
# retry_after_header = e.response.headers.get("Retry-After")
# if (
# e.response.status_code == 429
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
# ):
# retry_after = None
# if retry_after_header is not None:
# try:
# retry_after = int(retry_after_header)
# except ValueError:
# pass
# if retry_after is not None:
# if retry_after > max_retry_after:
# logger.warning(
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
# )
# retry_after = max_delay
# logger.warning(
# f"Rate limit hit. Retrying after {retry_after} seconds..."
# )
# try:
# r.set(
# NEXT_RETRY_KEY,
# math.ceil(time.monotonic() + retry_after),
# )
# except ConnectionError:
# pass
# else:
# logger.warning(
# "Rate limit hit. Retrying with exponential backoff..."
# )
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# else:
# # re-raise, let caller handle
# raise
# except AttributeError as e:
# # Some error within the Confluence library, unclear why it fails.
# # Users reported it to be intermittent, so just retry
# logger.warning(f"Confluence Internal Error, retrying... {e}")
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# if attempt == max_retries - 1:
# raise e
# return cast(F, wrapped_call)
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
max_retries = 5
starting_delay = 5
backoff = 2
max_delay = 600
MAX_RETRIES = 5
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
for attempt in range(max_retries):
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
time.sleep(retry_after)
else:
logger.warning(
"Rate limit hit. Retrying with exponential backoff..."
)
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
else:
# re-raise, let caller handle
raise
delay_until = _handle_http_error(e, attempt)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
logger.warning(f"Confluence Internal Error, retrying... {e}")
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
if attempt == max_retries - 1:
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)

View File

@@ -4,6 +4,7 @@ from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import DocumentSourceRequiringTenantContext
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)

View File

@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
@@ -27,6 +28,7 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -159,10 +161,12 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +175,9 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine()) as db_session:
token = current_tenant_id.set(self.tenant_id)
with get_session_with_tenant(self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
@@ -193,6 +199,8 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
current_tenant_id.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])

View File

@@ -462,8 +462,34 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
text_contents = extract_text(file, service) or ""
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(

View File

@@ -11,6 +11,8 @@ GenerateDocumentsOutput = Iterator[list[Document]]
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError

View File

@@ -45,8 +45,7 @@ class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
raise ValueError(
'ERROR: Name of family "{}" must be ASCII letters and digits [a-zA-Z0-9]',
name,
f'ERROR: Name of family "{name}" must be ASCII letters and digits [a-zA-Z0-9]',
)
if isinstance(dointerwiki, bool):

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import datetime
import itertools
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import ClassVar
@@ -19,6 +20,9 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.mediawiki.family import family_class_dispatch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
def pywikibot_timestamp_to_utc_datetime(
@@ -74,7 +78,7 @@ def get_doc_from_page(
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=page.pageid,
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
)
@@ -117,13 +121,18 @@ class MediaWikiConnector(LoadConnector, PollConnector):
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
for category in categories
]
self.pages = [pywikibot.Page(self.site, page) for page in pages]
self.pages = []
for page in pages:
if not page:
continue
self.pages.append(pywikibot.Page(self.site, page))
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load credentials for a MediaWiki site.
@@ -169,8 +178,13 @@ class MediaWikiConnector(LoadConnector, PollConnector):
]
# Since we can specify both individual pages and categories, we need to iterate over all of them.
all_pages = itertools.chain(self.pages, *category_pages)
all_pages: Iterator[pywikibot.Page] = itertools.chain(
self.pages, *category_pages
)
for page in all_pages:
logger.info(
f"MediaWikiConnector: title='{page.title()}' url={page.full_url()}"
)
doc_batch.append(
get_doc_from_page(page, self.site, self.document_source_type)
)

View File

@@ -29,6 +29,9 @@ logger = setup_logger()
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
@dataclass
class NotionPage:
"""Represents a Notion Page object"""
@@ -40,6 +43,8 @@ class NotionPage:
properties: dict[str, Any]
url: str
database_name: str | None # Only applicable to the database type page (wiki)
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
@@ -47,6 +52,17 @@ class NotionPage:
setattr(self, k, v)
@dataclass
class NotionBlock:
"""Represents a Notion Block object"""
id: str # Used for the URL
text: str
# In a plaintext representation of the page, how this block should be joined
# with the existing text up to this point, separated out from text for clarity
prefix: str
@dataclass
class NotionSearchResponse:
"""Represents the response from the Notion Search API"""
@@ -62,7 +78,6 @@ class NotionSearchResponse:
setattr(self, k, v)
# TODO - Add the ability to optionally limit to specific Notion databases
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages
this integration has been granted access to.
@@ -126,21 +141,47 @@ class NotionConnector(LoadConnector, PollConnector):
@retry(tries=3, delay=1, backoff=2)
def _fetch_page(self, page_id: str) -> NotionPage:
"""Fetch a page from it's ID via the Notion API."""
"""Fetch a page from its ID via the Notion API, retry with database if page fetch fails."""
logger.debug(f"Fetching page for ID '{page_id}'")
block_url = f"https://api.notion.com/v1/pages/{page_id}"
page_url = f"https://api.notion.com/v1/pages/{page_id}"
res = rl_requests.get(
block_url,
page_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
logger.exception(f"Error fetching page - {res.json()}")
raise e
logger.warning(
f"Failed to fetch page, trying database for ID '{page_id}'. Exception: {e}"
)
# Try fetching as a database if page fetch fails, this happens if the page is set to a wiki
# it becomes a database from the notion perspective
return self._fetch_database_as_page(page_id)
return NotionPage(**res.json())
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page."""
logger.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
res = rl_requests.get(
database_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
logger.exception(f"Error fetching database as page - {res.json()}")
raise e
database_name = res.json().get("title")
database_name = (
database_name[0].get("text", {}).get("content") if database_name else None
)
return NotionPage(**res.json(), database_name=database_name)
@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
self, database_id: str, cursor: str | None = None
@@ -171,8 +212,75 @@ class NotionConnector(LoadConnector, PollConnector):
raise e
return res.json()
def _read_pages_from_database(self, database_id: str) -> list[str]:
"""Returns a list of all page IDs in the database"""
@staticmethod
def _properties_to_str(properties: dict[str, Any]) -> str:
"""Converts Notion properties to a string"""
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
while "type" in inner_dict:
type_name = inner_dict["type"]
inner_dict = inner_dict[type_name]
# If the innermost layer is None, the value is not set
if not inner_dict:
return None
if isinstance(inner_dict, list):
list_properties = [
_recurse_properties(item) for item in inner_dict if item
]
return (
", ".join(
[
list_property
for list_property in list_properties
if list_property
]
)
or None
)
# TODO there may be more types to handle here
if "name" in inner_dict:
return inner_dict["name"]
if "content" in inner_dict:
return inner_dict["content"]
start = inner_dict.get("start")
end = inner_dict.get("end")
if start is not None:
if end is not None:
return f"{start} - {end}"
return start
elif end is not None:
return f"Until {end}"
if "id" in inner_dict:
# This is not useful to index, it's a reference to another Notion object
# and this ID value in plaintext is useless outside of the Notion context
logger.debug("Skipping Notion object id field property")
return None
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
return None
result = ""
for prop_name, prop in properties.items():
if not prop:
continue
inner_value = _recurse_properties(prop)
# Not a perfect way to format Notion database tables but there's no perfect representation
# since this must be represented as plaintext
if inner_value:
result += f"{prop_name}: {inner_value}\t"
return result
def _read_pages_from_database(
self, database_id: str
) -> tuple[list[NotionBlock], list[str]]:
"""Returns a list of top level blocks and all page IDs in the database"""
result_blocks: list[NotionBlock] = []
result_pages: list[str] = []
cursor = None
while True:
@@ -181,29 +289,34 @@ class NotionConnector(LoadConnector, PollConnector):
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
)
result_pages.append(result["id"])
elif obj_type == "database":
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
result_pages.extend(self._read_pages_from_database(obj_id))
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
if self.recursive_index_enabled:
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
)
result_pages.append(result["id"])
elif obj_type == "database":
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
# The inner contents are ignored at this level
_, child_pages = self._read_pages_from_database(obj_id)
result_pages.extend(child_pages)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_pages
return result_blocks, result_pages
def _read_blocks(
self, base_block_id: str
) -> tuple[list[tuple[str, str]], list[str]]:
"""Reads all child blocks for the specified block"""
result_lines: list[tuple[str, str]] = []
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
"""Reads all child blocks for the specified block, returns a list of blocks and child page ids"""
result_blocks: list[NotionBlock] = []
child_pages: list[str] = []
cursor = None
while True:
@@ -211,7 +324,7 @@ class NotionConnector(LoadConnector, PollConnector):
# this happens when a block is not shared with the integration
if data is None:
return result_lines, child_pages
return result_blocks, child_pages
for result in data["results"]:
logger.debug(
@@ -255,46 +368,70 @@ class NotionConnector(LoadConnector, PollConnector):
if result["has_children"]:
if result_type == "child_page":
# Child pages will not be included at this top level, it will be a separate document
child_pages.append(result_block_id)
else:
logger.debug(f"Entering sub-block: {result_block_id}")
subblock_result_lines, subblock_child_pages = self._read_blocks(
subblocks, subblock_child_pages = self._read_blocks(
result_block_id
)
logger.debug(f"Finished sub-block: {result_block_id}")
result_lines.extend(subblock_result_lines)
result_blocks.extend(subblocks)
child_pages.extend(subblock_child_pages)
if result_type == "child_database" and self.recursive_index_enabled:
child_pages.extend(self._read_pages_from_database(result_block_id))
if result_type == "child_database":
inner_blocks, inner_child_pages = self._read_pages_from_database(
result_block_id
)
# A database on a page often looks like a table, we need to include it for the contents
# of the page but the children (cells) should be processed as other Documents
result_blocks.extend(inner_blocks)
cur_result_text = "\n".join(cur_result_text_arr)
if cur_result_text:
result_lines.append((cur_result_text, result_block_id))
if self.recursive_index_enabled:
child_pages.extend(inner_child_pages)
if cur_result_text_arr:
new_block = NotionBlock(
id=result_block_id,
text="\n".join(cur_result_text_arr),
prefix="\n",
)
result_blocks.append(new_block)
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_lines, child_pages
return result_blocks, child_pages
def _read_page_title(self, page: NotionPage) -> str:
def _read_page_title(self, page: NotionPage) -> str | None:
"""Extracts the title from a Notion page"""
page_title = None
if hasattr(page, "database_name") and page.database_name:
return page.database_name
for _, prop in page.properties.items():
if prop["type"] == "title" and len(prop["title"]) > 0:
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
break
if page_title is None:
page_title = f"Untitled Page [{page.id}]"
return page_title
def _read_pages(
self,
pages: list[NotionPage],
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents"""
"""Reads pages for rich text content and generates Documents
Note that a page which is turned into a "wiki" becomes a database but both top level pages and top level databases
do not seem to have any properties associated with them.
Pages that are part of a database can have properties which are like the values of the row in the "database" table
in which they exist
This is not clearly outlined in the Notion API docs but it is observable empirically.
https://developers.notion.com/docs/working-with-page-content
"""
all_child_page_ids: list[str] = []
for page in pages:
if page.id in self.indexed_pages:
@@ -304,18 +441,23 @@ class NotionConnector(LoadConnector, PollConnector):
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
page_title = self._read_page_title(page)
if not page_blocks:
continue
page_title = (
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
)
yield (
Document(
id=page.id,
# Will add title to the first section later in processing
sections=[Section(link=page.url, text="")]
+ [
sections=[
Section(
link=f"{page.url}#{block_id.replace('-', '')}",
text=block_text,
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block_text, block_id in page_blocks
for block in page_blocks
],
source=DocumentSource.NOTION,
semantic_identifier=page_title,

View File

@@ -128,6 +128,9 @@ def get_internal_links(
if not href:
continue
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
if should_ignore_pound and "#" in href:
href = href.split("#")[0]

View File

@@ -0,0 +1,66 @@
from mistune import Markdown # type: ignore
from mistune import Renderer # type: ignore
def format_slack_message(message: str | None) -> str:
renderer = Markdown(renderer=SlackRenderer())
return renderer.render(message)
class SlackRenderer(Renderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)

View File

@@ -4,9 +4,7 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.views import View
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
@@ -35,20 +33,22 @@ from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger = setup_logger()
def handle_doc_feedback_button(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger.error("Missing actions. Unable to build the source feedback view")
@@ -81,7 +81,7 @@ def handle_doc_feedback_button(
def handle_generate_answer_button(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
) -> None:
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
@@ -116,7 +116,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts,
)
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -136,6 +136,7 @@ def handle_generate_answer_button(
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
channel=channel_id,
logger=logger,
feedback_reminder_id=None,
@@ -150,12 +151,11 @@ def handle_slack_feedback(
user_id_to_post_confirmation: str,
channel_id_to_post_confirmation: str,
thread_ts_to_post_confirmation: str,
tenant_id: str | None,
) -> None:
engine = get_sqlalchemy_engine()
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with Session(engine) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@@ -232,7 +232,7 @@ def handle_slack_feedback(
def handle_followup_button(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
) -> None:
action_id = None
if actions := req.payload.get("actions"):
@@ -252,7 +252,7 @@ def handle_followup_button(
tag_ids: list[str] = []
group_ids: list[str] = []
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(client.tenant_id) as db_session:
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)
@@ -295,7 +295,7 @@ def handle_followup_button(
def get_clicker_name(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
) -> str:
clicker_name = req.payload.get("user", {}).get("name", "Someone")
clicker_real_name = None
@@ -316,7 +316,7 @@ def get_clicker_name(
def handle_followup_resolved_button(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
immediate: bool = False,
) -> None:
channel_id = req.payload["container"]["channel_id"]

View File

@@ -2,7 +2,6 @@ import datetime
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
@@ -19,7 +18,7 @@ from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
@@ -110,6 +109,7 @@ def handle_message(
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str | None,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -135,7 +135,9 @@ def handle_message(
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
slack_usage_report(
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
@@ -209,7 +211,7 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(db_session, message_info.email)
@@ -235,5 +237,6 @@ def handle_message(
channel=channel,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
tenant_id=tenant_id,
)
return issue_with_regular_answer

View File

@@ -5,12 +5,10 @@ from typing import cast
from typing import Optional
from typing import TypeVar
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
@@ -28,12 +26,13 @@ from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
@@ -88,6 +87,7 @@ def handle_regular_answer(
channel: str,
logger: DanswerLoggingAdapter,
feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
@@ -104,8 +104,7 @@ def handle_regular_answer(
user = None
if message_info.is_bot_dm:
if message_info.email:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
@@ -152,14 +151,10 @@ def handle_regular_answer(
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if len(new_message_request.messages) > 1:
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)
raise RuntimeError("Slack bot does not support persona config")
elif new_message_request.persona_id is not None:
persona = cast(
Persona,
@@ -170,6 +165,10 @@ def handle_regular_answer(
get_editable=False,
),
)
else:
raise RuntimeError(
"No persona id provided, this should never happen."
)
llm, _ = get_llms_for_persona(persona)
@@ -246,7 +245,7 @@ def handle_regular_answer(
)
# Always apply reranking settings if it exists, this is the non-streaming flow
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion
@@ -413,10 +412,11 @@ def handle_regular_answer(
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,

View File

@@ -4,11 +4,10 @@ from typing import Any
from typing import cast
from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@@ -47,7 +46,8 @@ from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -80,7 +80,7 @@ _SLACK_GREETINGS_TO_IGNORE = {
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
# Verify channel is valid
@@ -153,8 +153,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
client=client.web_client, channel_id=channel
)
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -221,7 +220,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
return True
def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
feedback_type = cast(str, action.get("action_id"))
@@ -243,6 +242,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts,
tenant_id=client.tenant_id,
)
query_event_id, _, _ = decompose_action_id(feedback_id)
@@ -250,7 +250,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
def build_request_details(
req: SocketModeRequest, client: SocketModeClient
req: SocketModeRequest, client: TenantSocketModeClient
) -> SlackMessageInfo:
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
@@ -329,7 +329,7 @@ def build_request_details(
def apologize_for_fail(
details: SlackMessageInfo,
client: SocketModeClient,
client: TenantSocketModeClient,
) -> None:
respond_in_thread(
client=client.web_client,
@@ -341,7 +341,7 @@ def apologize_for_fail(
def process_message(
req: SocketModeRequest,
client: SocketModeClient,
client: TenantSocketModeClient,
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
@@ -357,8 +357,7 @@ def process_message(
client=client.web_client, channel_id=channel
)
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -390,6 +389,7 @@ def process_message(
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
if failed:
@@ -404,12 +404,12 @@ def process_message(
apologize_for_fail(details, client)
def acknowledge_message(req: SocketModeRequest, client: SocketModeClient) -> None:
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
response = SocketModeResponse(envelope_id=req.envelope_id)
client.send_socket_mode_response(response)
def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
@@ -429,13 +429,13 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if view := req.payload.get("view"):
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
return process_feedback(req, client)
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
@@ -453,21 +453,24 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
logger.error(f"Slack request payload: {req.payload}")
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient:
def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.danswer.dev/slack_bot_setup
return SocketModeClient(
return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
tenant_id=tenant_id,
)
def _initialize_socket_client(socket_client: SocketModeClient) -> None:
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice("Listening for messages from Slack...")
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
@@ -481,8 +484,8 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
set_is_ee_based_on_env_variable()
@@ -491,46 +494,59 @@ if __name__ == "__main__":
while True:
try:
latest_slack_bot_tokens = fetch_tokens()
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
if latest_slack_bot_tokens != slack_bot_tokens:
if slack_bot_tokens is not None:
logger.notice("Slack Bot tokens have changed - reconnecting")
else:
# This happens on the very first time the listener process comes up
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
latest_slack_bot_tokens = fetch_tokens()
warm_up_bi_encoder(
embedding_model=embedding_model,
)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if socket_client:
socket_client.close()
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
socket_client = _get_socket_client(slack_bot_tokens)
_initialize_socket_client(socket_client)
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
# Let the handlers run in the background + re-check for token updates every 60 seconds
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except KvKeyNotFoundError:
# try again every 30 seconds. This is needed since the user may add tokens
# via the UI at any point in the programs lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens
logger.debug(
"Missing Slack Bot tokens - waiting 60 seconds and trying again"
)
if socket_client:
socket_client.disconnect()
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)

View File

@@ -12,7 +12,7 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata
from sqlalchemy.orm import Session
from slack_sdk.socket_mode import SocketModeClient
from danswer.configs.app_configs import DISABLE_TELEMETRY
from danswer.configs.constants import ID_SEPARATOR
@@ -31,7 +31,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
@@ -489,7 +489,9 @@ def read_slack_thread(
return thread_messages
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
def slack_usage_report(
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
) -> None:
if DISABLE_TELEMETRY:
return
@@ -501,7 +503,7 @@ def slack_usage_report(action: str, sender_id: str | None, client: WebClient) ->
logger.warning("Unable to find sender email")
if sender_email is not None:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
danswer_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry(
@@ -577,3 +579,9 @@ def get_feedback_visibility() -> FeedbackVisibility:
return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower())
except ValueError:
return FeedbackVisibility.PRIVATE
class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id

View File

@@ -10,10 +10,12 @@ from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDataba
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
@@ -33,10 +35,20 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users
async def get_user_count() -> int:
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")

View File

@@ -43,7 +43,7 @@ logger = setup_logger()
def get_chat_session_by_id(
chat_session_id: int,
chat_session_id: UUID,
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
@@ -87,9 +87,9 @@ def get_chat_sessions_by_slack_thread_id(
def get_valid_messages_from_query_sessions(
chat_session_ids: list[int],
chat_session_ids: list[UUID],
db_session: Session,
) -> dict[int, str]:
) -> dict[UUID, str]:
user_message_subquery = (
select(
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
@@ -196,7 +196,7 @@ def delete_orphaned_search_docs(db_session: Session) -> None:
def delete_messages_and_files_from_chat_session(
chat_session_id: int, db_session: Session
chat_session_id: UUID, db_session: Session
) -> None:
# Select messages older than cutoff_time with files
messages_with_files = db_session.execute(
@@ -253,7 +253,7 @@ def create_chat_session(
def update_chat_session(
db_session: Session,
user_id: UUID | None,
chat_session_id: int,
chat_session_id: UUID,
description: str | None = None,
sharing_status: ChatSessionSharedStatus | None = None,
) -> ChatSession:
@@ -276,7 +276,7 @@ def update_chat_session(
def delete_chat_session(
user_id: UUID | None,
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
@@ -337,7 +337,7 @@ def get_chat_message(
def get_chat_messages_by_sessions(
chat_session_ids: list[int],
chat_session_ids: list[UUID],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
@@ -370,7 +370,7 @@ def get_search_docs_for_chat_message(
def get_chat_messages_by_session(
chat_session_id: int,
chat_session_id: UUID,
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
@@ -397,7 +397,7 @@ def get_chat_messages_by_session(
def get_or_create_root_message(
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
) -> ChatMessage:
try:
@@ -433,7 +433,7 @@ def get_or_create_root_message(
def reserve_message_id(
db_session: Session,
chat_session_id: int,
chat_session_id: UUID,
parent_message: int,
message_type: MessageType,
) -> int:
@@ -460,7 +460,7 @@ def reserve_message_id(
def create_new_chat_message(
chat_session_id: int,
chat_session_id: UUID,
parent_message: ChatMessage,
message: str,
prompt_id: int | None,

View File

@@ -248,7 +248,7 @@ def create_initial_default_connector(db_session: Session) -> None:
logger.warning(
"Default connector does not have expected values. Updating to proper state."
)
# Ensure default connector has correct valuesg
# Ensure default connector has correct values
default_connector.source = DocumentSource.INGESTION_API
default_connector.input_type = InputType.LOAD_STATE
default_connector.refresh_freq = None

View File

@@ -390,6 +390,7 @@ def add_credential_to_connector(
)
db_session.add(association)
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(

View File

@@ -241,8 +241,7 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID
db_session.flush() # This ensures the credential gets an IDcredentials
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,

View File

@@ -398,7 +398,7 @@ def mark_document_set_as_to_be_deleted(
def delete_document_set_cc_pair_relationship__no_commit(
connector_id: int, credential_id: int, db_session: Session
) -> None:
) -> int:
"""Deletes all rows from DocumentSet__ConnectorCredentialPair where the
connector_credential_pair_id matches the given cc_pair_id."""
delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
@@ -409,7 +409,8 @@ def delete_document_set_cc_pair_relationship__no_commit(
== ConnectorCredentialPair.id,
)
)
db_session.execute(delete_stmt)
result = db_session.execute(delete_stmt)
return result.rowcount # type: ignore
def fetch_document_sets(

View File

@@ -1,16 +1,16 @@
import contextlib
import contextvars
import re
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import ContextManager
import jwt
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
@@ -39,7 +39,7 @@ from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -230,18 +230,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
# Context variable to store the current tenant ID
# This allows us to maintain tenant-specific context throughout the request lifecycle
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
# This context variable works in both synchronous and asynchronous contexts
# In async code, it's automatically carried across coroutines
# In sync code, it's managed per thread
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
# Dependency to get the current tenant ID and set the context variable
# Dependency to get the current tenant ID
# If no token is present, uses the default schema for this use case
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
@@ -251,32 +241,31 @@ def get_current_tenant_id(request: Request) -> str:
token = request.cookies.get("tenant_details")
if not token:
current_value = current_tenant_id.get()
# If no token is present, use the default schema or handle accordingly
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
return current_value
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400, detail="Invalid token: tenant_id missing"
)
return current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise ValueError("Invalid tenant ID format")
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
current_tenant_id.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token format")
except ValueError as e:
# Let the 400 error bubble up
raise HTTPException(status_code=400, detail=str(e))
except Exception:
return current_tenant_id.get()
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
@asynccontextmanager
async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = current_tenant_id.get()
@@ -284,21 +273,94 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
session = Session(engine, expire_on_commit=False)
engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore
@event.listens_for(session, "after_begin")
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
return session
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
yield session
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
with get_session_with_tenant(tenant_id) as session:
yield session
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
if tenant_id == "public" and MULTI_TENANT:
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
@@ -308,10 +370,9 @@ def get_session(
yield session
async def get_async_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
@@ -324,7 +385,7 @@ async def get_async_session(
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session)()
return contextlib.contextmanager(get_session_generator_with_tenant)()
def get_session_factory() -> sessionmaker[Session]:

View File

@@ -1,4 +1,6 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from sqlalchemy import and_
from sqlalchemy import delete
@@ -19,8 +21,6 @@ from danswer.db.models import SearchSettings
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
logger = setup_logger()
@@ -66,7 +66,7 @@ def create_index_attempt(
return new_attempt.id
def get_inprogress_index_attempts(
def get_in_progress_index_attempts(
connector_id: int | None,
db_session: Session,
) -> list[IndexAttempt]:
@@ -81,13 +81,15 @@ def get_inprogress_index_attempts(
return list(incomplete_attempts.all())
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
def get_all_index_attempts_by_status(
status: IndexingStatus, db_session: Session
) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.where(IndexAttempt.status == status)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
@@ -101,31 +103,92 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
return list(new_attempts.all())
def transition_attempt_to_in_progress(
index_attempt_id: int,
db_session: Session,
) -> IndexAttempt:
"""Locks the row when we try to update"""
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = attempt.time_started or func.now() # type: ignore
db_session.commit()
return attempt
except Exception:
db_session.rollback()
logger.exception("transition_attempt_to_in_progress exceptioned.")
raise
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_succeeded(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.SUCCESS
db_session.add(index_attempt)
db_session.commit()
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_partially_succeeded(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.add(index_attempt)
db_session.commit()
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_failed(
@@ -134,14 +197,22 @@ def mark_attempt_failed(
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
) -> None:
index_attempt.status = IndexingStatus.FAILED
index_attempt.error_msg = failure_reason
index_attempt.full_exception_trace = full_exception_trace
db_session.add(index_attempt)
db_session.commit()
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
except Exception:
db_session.rollback()
raise
def update_docs_indexed(
@@ -435,14 +506,13 @@ def cancel_indexing_attempts_for_ccpair(
db_session.execute(stmt)
db_session.commit()
def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
any embedding model that not present/future"""
db_session.execute(
update(IndexAttempt)
.where(
@@ -455,8 +525,6 @@ def cancel_indexing_attempts_past_model(
.values(status=IndexingStatus.FAILED)
)
db_session.commit()
def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,

View File

@@ -83,6 +83,7 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
existing_llm_provider.display_model_names = llm_provider.display_model_names
existing_llm_provider.deployment_name = llm_provider.deployment_name
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing

View File

@@ -5,9 +5,12 @@ from typing import Any
from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from typing_extensions import TypedDict # noreorder
from uuid import UUID
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
@@ -57,6 +60,7 @@ from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
@@ -231,6 +235,9 @@ class Notification(Base):
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
additional_data: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
"""
@@ -615,6 +622,7 @@ class SearchSettings(Base):
normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False)
)
@@ -670,6 +678,20 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_version(self) -> str | None:
return (
self.cloud_provider.api_version if self.cloud_provider is not None else None
)
@property
def deployment_name(self) -> str | None:
return (
self.cloud_provider.deployment_name
if self.cloud_provider is not None
else None
)
@property
def api_url(self) -> str | None:
return self.cloud_provider.api_url if self.cloud_provider is not None else None
@@ -905,7 +927,9 @@ class ToolCall(Base):
class ChatSession(Base):
__tablename__ = "chat_session"
id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
@@ -975,7 +999,9 @@ class ChatMessage(Base):
__tablename__ = "chat_message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
alternate_assistant_id = mapped_column(
Integer, ForeignKey("persona.id"), nullable=True
@@ -1143,6 +1169,8 @@ class LLMProvider(Base):
postgresql.ARRAY(String), nullable=True
)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
@@ -1162,6 +1190,9 @@ class CloudEmbeddingProvider(Base):
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
back_populates="cloud_provider",
@@ -1261,7 +1292,7 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
@@ -1761,3 +1792,23 @@ class UsageReport(Base):
requestor = relationship("User")
file = relationship("PGFileStore")
"""
Multi-tenancy related tables
"""
class PublicBase(DeclarativeBase):
__abstract__ = True
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)

View File

@@ -1,23 +1,47 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from danswer.auth.schemas import UserRole
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification
from danswer.db.models import User
def create_notification(
user: User | None,
user_id: UUID | None,
notif_type: NotificationType,
db_session: Session,
additional_data: dict | None = None,
) -> Notification:
# Check if an undismissed notification of the same type and data exists
existing_notification = (
db_session.query(Notification)
.filter_by(
user_id=user_id,
notif_type=notif_type,
dismissed=False,
)
.filter(Notification.additional_data == additional_data)
.first()
)
if existing_notification:
# Update the last_shown timestamp
existing_notification.last_shown = func.now()
db_session.commit()
return existing_notification
# Create a new notification if none exists
notification = Notification(
user_id=user.id if user else None,
user_id=user_id,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
additional_data=additional_data,
)
db_session.add(notification)
db_session.commit()
@@ -31,7 +55,9 @@ def get_notification_by_id(
notif = db_session.get(Notification, notification_id)
if not notif:
raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id:
if notif.user_id != user_id and not (
notif.user_id is None and user.role == UserRole.ADMIN
):
raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}"
)

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -1,5 +1,6 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
@@ -8,16 +9,18 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_index_swap(db_session: Session) -> None:
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one."""
@@ -27,7 +30,7 @@ def check_index_swap(db_session: Session) -> None:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
return
return None
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
@@ -63,3 +66,7 @@ def check_index_swap(db_session: Session) -> None:
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
if MULTI_TENANT:
return now_old_search_settings
return None

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from uuid import UUID
from sqlalchemy import select
@@ -6,6 +7,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -67,7 +69,9 @@ def update_tool(
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
db_session.commit()
return tool

View File

@@ -1,5 +1,6 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@@ -14,7 +15,9 @@ def get_default_document_index(
index both need to be updated, updates are applied to both indices"""
# Currently only supporting Vespa
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
multitenant=MULTI_TENANT,
)

View File

@@ -127,6 +127,17 @@ class Verifiable(abc.ABC):
"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
) -> None:
"""
Register multitenant indices with the document index.
"""
raise NotImplementedError
class Indexable(abc.ABC):
"""
@@ -172,7 +183,7 @@ class Deletable(abc.ABC):
"""
@abc.abstractmethod
def delete_single(self, doc_id: str) -> None:
def delete_single(self, doc_id: str) -> int:
"""
Given a single document id, hard delete it from the document index
@@ -203,7 +214,7 @@ class Updatable(abc.ABC):
"""
@abc.abstractmethod
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
"""
Updates all chunks for a document with the specified fields.
None values mean that the field does not need an update.

View File

@@ -1,5 +1,6 @@
schema DANSWER_CHUNK_NAME {
document DANSWER_CHUNK_NAME {
TENANT_ID_REPLACEMENT
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute

View File

@@ -7,11 +7,13 @@ from datetime import timezone
from typing import Any
from typing import cast
import httpx
import requests
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
@@ -293,13 +295,12 @@ def query_vespa(
if LOG_VESPA_TIMING_INFORMATION
else {},
)
try:
response = requests.post(
SEARCH_ENDPOINT,
json=params,
)
response.raise_for_status()
except requests.HTTPError as e:
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
except httpx.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = (
f"Status Code: {response.status_code}\n"
@@ -312,9 +313,10 @@ def query_vespa(
f"{response_info}\n"
f"Exception: {e}"
)
raise requests.HTTPError(error_base) from e
raise httpx.HTTPError(error_base) from e
response_json: dict[str, Any] = response.json()
if LOG_VESPA_TIMING_INFORMATION:
logger.debug("Vespa timing info: %s", response_json.get("timing"))
hits = response_json["root"].get("children", [])

View File

@@ -4,17 +4,20 @@ import logging
import os
import re
import time
import urllib
import zipfile
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from typing import BinaryIO
from typing import cast
from typing import List
import httpx
import requests
import httpx # type: ignore
import requests # type: ignore
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -39,6 +42,7 @@ from danswer.document_index.vespa.indexing_utils import clean_chunk_id_copy
from danswer.document_index.vespa.indexing_utils import (
get_existing_documents_from_chunks,
)
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
@@ -57,6 +61,8 @@ from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import NUM_THREADS
from danswer.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from danswer.document_index.vespa_constants import TENANT_ID_PAT
from danswer.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
@@ -69,6 +75,7 @@ from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs
@@ -92,7 +99,7 @@ def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
return zip_buffer
def _create_document_xml_lines(doc_names: list[str | None]) -> str:
def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
doc_lines = [
f'<document type="{doc_name}" mode="index" />'
for doc_name in doc_names
@@ -117,15 +124,28 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
def __init__(
self,
index_name: str,
secondary_index_name: str | None,
multitenant: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.multitenant = multitenant
self.http_client = get_vespa_http_client()
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.info(f"Deploying Vespa application package to {deploy_url}")
@@ -173,10 +193,14 @@ class VespaIndex(DocumentIndex):
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
@@ -194,6 +218,91 @@ class VespaIndex(DocumentIndex):
f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
)
@staticmethod
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
) -> None:
if not MULTI_TENANT:
raise ValueError("Multi-tenant is not enabled")
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.info(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "danswer", "document_index", "vespa", "app_config"
)
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
services_file = os.path.join(vespa_schema_path, "services.xml")
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
with open(services_file, "r") as services_f:
services_template = services_f.read()
# Generate schema names from index settings
schema_names = [index_name for index_name in indices]
full_schemas = schema_names
doc_lines = _create_document_xml_lines(full_schemas)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from current so we set it to 7 dynamically
now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
zip_dict = {
"services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"),
}
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i]
logger.info(
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
)
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
schema = schema.replace(
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
raise RuntimeError(
f"Failed to prepare Vespa Danswer Indexes. Response: {response.text}"
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
@@ -211,7 +320,7 @@ class VespaIndex(DocumentIndex):
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx.Client(http2=True) as http_client,
get_vespa_http_client() as http_client,
):
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
@@ -239,6 +348,7 @@ class VespaIndex(DocumentIndex):
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
multitenant=self.multitenant,
executor=executor,
)
@@ -273,9 +383,10 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficient for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx.Client(http2=True) as http_client,
get_vespa_http_client() as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
@@ -384,12 +495,14 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
total_chunks_updated = 0
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
normalized_doc_id = replace_invalid_doc_id_characters(doc_id)
@@ -411,13 +524,13 @@ class VespaIndex(DocumentIndex):
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
return 0
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
with get_vespa_http_client() as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -426,7 +539,6 @@ class VespaIndex(DocumentIndex):
}
)
total_chunks_updated = 0
while True:
try:
resp = http_client.put(
@@ -462,9 +574,10 @@ class VespaIndex(DocumentIndex):
f"VespaIndex.update_single: "
f"index={index_name} "
f"doc={normalized_doc_id} "
f"chunks_deleted={total_chunks_updated}"
f"chunks_updated={total_chunks_updated}"
)
return
return total_chunks_updated
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
@@ -473,7 +586,7 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with httpx.Client(http2=True) as http_client:
with get_vespa_http_client() as http_client:
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
@@ -484,10 +597,12 @@ class VespaIndex(DocumentIndex):
)
return
def delete_single(self, doc_id: str) -> None:
def delete_single(self, doc_id: str) -> int:
"""Possibly faster overall than the delete method due to using a single
delete call with a selection query."""
total_chunks_deleted = 0
# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example
@@ -499,7 +614,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
with get_vespa_http_client() as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -508,7 +623,6 @@ class VespaIndex(DocumentIndex):
}
)
total_chunks_deleted = 0
while True:
try:
resp = http_client.delete(
@@ -543,7 +657,8 @@ class VespaIndex(DocumentIndex):
f"doc={doc_id} "
f"chunks_deleted={total_chunks_deleted}"
)
return
return total_chunks_deleted
def id_based_retrieval(
self,
@@ -639,3 +754,158 @@ class VespaIndex(DocumentIndex):
}
return query_vespa(params)
@classmethod
def delete_entries_by_tenant_id(cls, tenant_id: str, index_name: str) -> None:
"""
Deletes all entries in the specified index with the given tenant_id.
Parameters:
tenant_id (str): The tenant ID whose documents are to be deleted.
index_name (str): The name of the index from which to delete documents.
"""
logger.info(
f"Deleting entries with tenant_id: {tenant_id} from index: {index_name}"
)
# Step 1: Retrieve all document IDs with the given tenant_id
document_ids = cls._get_all_document_ids_by_tenant_id(tenant_id, index_name)
if not document_ids:
logger.info(
f"No documents found with tenant_id: {tenant_id} in index: {index_name}"
)
return
# Step 2: Delete documents in batches
delete_requests = [
_VespaDeleteRequest(document_id=doc_id, index_name=index_name)
for doc_id in document_ids
]
cls._apply_deletes_batched(delete_requests)
@classmethod
def _get_all_document_ids_by_tenant_id(
cls, tenant_id: str, index_name: str
) -> List[str]:
"""
Retrieves all document IDs with the specified tenant_id, handling pagination.
Parameters:
tenant_id (str): The tenant ID to search for.
index_name (str): The name of the index to search in.
Returns:
List[str]: A list of document IDs matching the tenant_id.
"""
offset = 0
limit = 1000 # Vespa's maximum hits per query
document_ids = []
logger.debug(
f"Starting document ID retrieval for tenant_id: {tenant_id} in index: {index_name}"
)
while True:
# Construct the query to fetch document IDs
query_params = {
"yql": f'select id from sources * where tenant_id contains "{tenant_id}";',
"offset": str(offset),
"hits": str(limit),
"timeout": "10s",
"format": "json",
"summary": "id",
}
url = f"{VESPA_APPLICATION_ENDPOINT}/search/"
logger.debug(
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
)
with get_vespa_http_client(no_timeout=True) as http_client:
response = http_client.get(url, params=query_params)
response.raise_for_status()
search_result = response.json()
hits = search_result.get("root", {}).get("children", [])
if not hits:
break
for hit in hits:
doc_id = hit.get("id")
if doc_id:
document_ids.append(doc_id)
offset += limit # Move to the next page
logger.debug(
f"Retrieved {len(document_ids)} document IDs for tenant_id: {tenant_id}"
)
return document_ids
@classmethod
def _apply_deletes_batched(
cls,
delete_requests: List["_VespaDeleteRequest"],
batch_size: int = BATCH_SIZE,
) -> None:
"""
Deletes documents in batches using multiple threads.
Parameters:
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
batch_size (int): The number of documents to delete in each batch.
"""
def _delete_document(
delete_request: "_VespaDeleteRequest", http_client: httpx.Client
) -> None:
logger.debug(f"Deleting document with ID {delete_request.document_id}")
response = http_client.delete(
delete_request.url,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
with get_vespa_http_client(no_timeout=True) as http_client:
for batch_start in range(0, len(delete_requests), batch_size):
batch = delete_requests[batch_start : batch_start + batch_size]
future_to_document_id = {
executor.submit(
_delete_document,
delete_request,
http_client,
): delete_request.document_id
for delete_request in batch
}
for future in concurrent.futures.as_completed(
future_to_document_id
):
doc_id = future_to_document_id[future]
try:
future.result()
logger.debug(f"Successfully deleted document: {doc_id}")
except httpx.HTTPError as e:
logger.error(f"Failed to delete document {doc_id}: {e}")
# Optionally, implement retry logic or error handling here
logger.info("Batch deletion completed")
class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:
self.document_id = document_id
# Encode the document ID to ensure it's safe for use in the URL
encoded_doc_id = urllib.parse.quote_plus(self.document_id)
self.url = (
f"{VESPA_APPLICATION_ENDPOINT}/document/v1/"
f"{index_name}/{index_name}/docid/{encoded_doc_id}"
)

View File

@@ -37,6 +37,7 @@ from danswer.document_index.vespa_constants import SEMANTIC_IDENTIFIER
from danswer.document_index.vespa_constants import SKIP_TITLE_EMBEDDING
from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import TITLE_EMBEDDING
from danswer.indexing.models import DocMetadataAwareIndexChunk
@@ -65,6 +66,8 @@ def _does_document_exist(
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
f"Index name: {index_name}"
f"Doc chunk id: {doc_chunk_id}"
)
return True
@@ -117,7 +120,10 @@ def get_existing_documents_from_chunks(
@retry(tries=3, delay=1, backoff=2)
def _index_vespa_chunk(
chunk: DocMetadataAwareIndexChunk, index_name: str, http_client: httpx.Client
chunk: DocMetadataAwareIndexChunk,
index_name: str,
http_client: httpx.Client,
multitenant: bool,
) -> None:
json_header = {
"Content-Type": "application/json",
@@ -174,6 +180,10 @@ def _index_vespa_chunk(
BOOST: chunk.boost,
}
if multitenant:
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
logger.debug(f'Indexing to URL "{vespa_url}"')
res = http_client.post(
@@ -192,6 +202,7 @@ def batch_index_vespa_chunks(
chunks: list[DocMetadataAwareIndexChunk],
index_name: str,
http_client: httpx.Client,
multitenant: bool,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
external_executor = True
@@ -202,7 +213,9 @@ def batch_index_vespa_chunks(
try:
chunk_index_future = {
executor.submit(_index_vespa_chunk, chunk, index_name, http_client): chunk
executor.submit(
_index_vespa_chunk, chunk, index_name, http_client, multitenant
): chunk
for chunk in chunks
}
for future in concurrent.futures.as_completed(chunk_index_future):

View File

@@ -1,4 +1,12 @@
import re
from typing import cast
import httpx
from danswer.configs.app_configs import MANAGED_VESPA
from danswer.configs.app_configs import VESPA_CLOUD_CERT_PATH
from danswer.configs.app_configs import VESPA_CLOUD_KEY_PATH
from danswer.configs.app_configs import VESPA_REQUEST_TIMEOUT
# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
@@ -45,3 +53,19 @@ def remove_invalid_unicode_chars(text: str) -> str:
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)
return _illegal_xml_chars_RE.sub("", text)
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
"""
return httpx.Client(
cert=cast(tuple[str, str], (VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH))
if MANAGED_VESPA
else None,
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=True,
)

View File

@@ -12,6 +12,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import METADATA_LIST
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.search.models import IndexFilters
from danswer.utils.logger import setup_logger
@@ -53,6 +54,9 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
if filters.tenant_id:
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
if filters.access_control_list is not None:
filter_str += _build_or_filters(

View File

@@ -1,3 +1,4 @@
from danswer.configs.app_configs import VESPA_CLOUD_URL
from danswer.configs.app_configs import VESPA_CONFIG_SERVER_HOST
from danswer.configs.app_configs import VESPA_HOST
from danswer.configs.app_configs import VESPA_PORT
@@ -9,17 +10,30 @@ DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}"""
# config server
VESPA_CONFIG_SERVER_URL = f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
VESPA_CONFIG_SERVER_URL = (
VESPA_CLOUD_URL or f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
)
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
# main search application
VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}"
VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
DOCUMENT_ID_ENDPOINT = (
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
)
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
NUM_THREADS = (
@@ -35,7 +49,7 @@ MAX_OR_CONDITIONS = 10
VESPA_TIMEOUT = "3s"
BATCH_SIZE = 128 # Specific to Vespa
TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"

View File

@@ -208,8 +208,9 @@ def read_pdf_file(
# By user request, keep files that are unreadable just so they
# can be discoverable by title.
return "", metadata
else:
logger.warning("No Password available to to decrypt pdf")
elif pdf_reader.is_encrypted:
logger.warning("No Password available to decrypt pdf, returning empty")
return "", metadata
# Extract metadata from the PDF, removing leading '/' from keys if present
# This standardizes the metadata keys for consistency

View File

@@ -4,11 +4,17 @@ from dataclasses import dataclass
from typing import IO
import bs4
import trafilatura # type: ignore
from trafilatura.settings import use_config # type: ignore
from danswer.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
from danswer.configs.app_configs import PARSE_WITH_TRAFILATURA
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from danswer.utils.logger import setup_logger
logger = setup_logger()
MINTLIFY_UNWANTED = ["sticky", "hidden"]
@@ -47,6 +53,18 @@ def format_element_text(element_text: str, link_href: str | None) -> str:
return f"[{element_text_no_newlines}]({link_href})"
def parse_html_with_trafilatura(html_content: str) -> str:
"""Parse HTML content using trafilatura."""
config = use_config()
config.set("DEFAULT", "include_links", "True")
config.set("DEFAULT", "include_tables", "True")
config.set("DEFAULT", "include_images", "True")
config.set("DEFAULT", "include_formatting", "True")
extracted_text = trafilatura.extract(html_content, config=config)
return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""
def format_document_soup(
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
@@ -183,7 +201,21 @@ def web_html_cleanup(
for undesired_tag in additional_element_types_to_discard:
[tag.extract() for tag in soup.find_all(undesired_tag)]
# 200B is ZeroWidthSpace which we don't care for
page_text = format_document_soup(soup).replace("\u200B", "")
soup_string = str(soup)
page_text = ""
return ParsedHTML(title=title, cleaned_text=page_text)
if PARSE_WITH_TRAFILATURA:
try:
page_text = parse_html_with_trafilatura(soup_string)
if not page_text:
raise ValueError("Empty content returned by trafilatura.")
except Exception as e:
logger.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
page_text = format_document_soup(soup)
else:
page_text = format_document_soup(soup)
# 200B is ZeroWidthSpace which we don't care for
cleaned_text = page_text.replace("\u200B", "")
return ParsedHTML(title=title, cleaned_text=cleaned_text)

View File

@@ -15,7 +15,7 @@ from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
# actually help quality at all
@@ -27,6 +27,7 @@ CHUNK_OVERLAP = 0
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
logger = setup_logger()
@@ -157,6 +158,24 @@ class Chunker:
else None
)
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
"""
Splits the text into smaller chunks based on token count to ensure
no chunk exceeds the content_token_limit.
"""
tokens = self.tokenizer.tokenize(text)
chunks = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
# Join the tokens to reconstruct the text
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks
def _extract_blurb(self, text: str) -> str:
texts = self.blurb_splitter.split_text(text)
if not texts:
@@ -217,14 +236,42 @@ class Chunker:
chunk_text = ""
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
split_token_count = len(self.tokenizer.tokenize(split_text))
if STRICT_CHUNK_TOKEN_LIMIT:
split_token_count = len(self.tokenizer.tokenize(split_text))
if split_token_count > content_token_limit:
# Further split the oversized chunk
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
)
)
else:
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
)
continue
current_token_count = len(self.tokenizer.tokenize(chunk_text))

View File

@@ -32,6 +32,8 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
@@ -41,6 +43,8 @@ class IndexingEmbedder(ABC):
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.embedding_model = EmbeddingModel(
model_name=model_name,
@@ -50,6 +54,8 @@ class IndexingEmbedder(ABC):
api_key=api_key,
provider_type=provider_type,
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@@ -75,6 +81,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
@@ -85,6 +93,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
api_version,
deployment_name,
heartbeat,
)
@@ -193,5 +203,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat,
)

View File

@@ -137,6 +137,7 @@ def index_doc_batch_with_handler(
attempt_id: int | None,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
try:
@@ -148,6 +149,7 @@ def index_doc_batch_with_handler(
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
)
except Exception as e:
if INDEXING_EXCEPTION_LIMIT == 0:
@@ -261,6 +263,7 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -324,6 +327,7 @@ def index_doc_batch(
if chunk.source_document.id in ctx.id_to_db_doc_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
)
for chunk in chunks_with_embeddings
]
@@ -373,6 +377,7 @@ def build_indexing_pipeline(
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -416,4 +421,5 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
attempt_id=attempt_id,
db_session=db_session,
tenant_id=tenant_id,
)

View File

@@ -75,6 +75,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
negative -> ranked lower.
"""
tenant_id: str | None = None
access: "DocumentAccess"
document_sets: set[str]
boost: int
@@ -86,6 +87,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
tenant_id: str | None,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(
@@ -93,6 +95,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
tenant_id=tenant_id,
)

View File

@@ -3,15 +3,21 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast
from fastapi import HTTPException
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import is_valid_schema_name
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -22,12 +28,23 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = current_tenant_id.get()
if tenant_id == "public":
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:

View File

@@ -79,6 +79,8 @@ def _get_answer_stream_processor(
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
print("ANSWERR STYES")
print(answer_style_configs.__dict__)
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
@@ -316,7 +318,9 @@ class Answer:
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return

View File

@@ -226,6 +226,7 @@ def process_model_tokens(
hold_quote = ""
for token in tokens:
print(f"Token: {token}")
model_previous = model_output
model_output += token

View File

@@ -109,7 +109,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"arguments": json.dumps(tool_call["args"]),
},
"type": "function",
"index": 0, # only support a single tool call atm
"index": tool_call.get("index", 0),
}
for tool_call in message.tool_calls
]
@@ -158,12 +158,13 @@ def _convert_delta_to_message_chunk(
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
idx = tool_call.index
tool_call_chunk = ToolCallChunk(
name=tool_name,
id=tool_call.id,
args=tool_call.function.arguments,
index=0, # only support a single tool call atm
index=idx,
)
return AIMessageChunk(
@@ -204,6 +205,7 @@ class DefaultMultiLLM(LLM):
model_name: str,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
max_output_tokens: int | None = None,
custom_llm_provider: str | None = None,
temperature: float = GEN_AI_TEMPERATURE,
@@ -215,6 +217,7 @@ class DefaultMultiLLM(LLM):
self._model_version = model_name
self._temperature = temperature
self._api_key = api_key
self._deployment_name = deployment_name
self._api_base = api_base
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
@@ -283,13 +286,14 @@ class DefaultMultiLLM(LLM):
_convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
for msg in prompt
]
elif isinstance(prompt, str):
prompt = [_convert_message_to_dict(HumanMessage(content=prompt))]
try:
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
@@ -324,6 +328,7 @@ class DefaultMultiLLM(LLM):
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
deployment_name=self._deployment_name,
)
def _invoke_implementation(

View File

@@ -7,9 +7,9 @@ from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers
def get_main_llm_from_tuple(
@@ -88,6 +88,7 @@ def get_default_llms(
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
@@ -103,6 +104,7 @@ def get_default_llms(
def get_llm(
provider: str,
model: str,
deployment_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
api_version: str | None = None,

View File

@@ -1,34 +0,0 @@
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
if not LITELLM_PASS_THROUGH_HEADERS:
return {}
pass_through_headers: dict[str, str] = {}
for key in LITELLM_PASS_THROUGH_HEADERS:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers

View File

@@ -24,7 +24,7 @@ class LLMConfig(BaseModel):
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -16,10 +16,13 @@ class WellKnownLLMProviderDescriptor(BaseModel):
api_base_required: bool
api_version_required: bool
custom_config_keys: list[CustomConfigKey] | None = None
llm_names: list[str]
default_model: str | None = None
default_fast_model: str | None = None
# set for providers like Azure, which require a deployment name.
deployment_name_required: bool = False
# set for providers like Azure, which support a single model per deployment.
single_model_supported: bool = False
OPENAI_PROVIDER_NAME = "openai"
@@ -108,6 +111,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=True,
custom_config_keys=[],
llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
deployment_name_required=True,
single_model_supported=True,
),
WellKnownLLMProviderDescriptor(
name=BEDROCK_PROVIDER_NAME,

View File

@@ -1,9 +1,11 @@
import sys
import traceback
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import cast
import sentry_sdk
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
@@ -14,6 +16,8 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
from danswer import __version__
@@ -28,10 +32,12 @@ from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import SYSTEM_RECURSION_LIMIT
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
@@ -51,6 +57,7 @@ from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@@ -78,6 +85,7 @@ from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.setup import setup_danswer
from danswer.setup import setup_multitenant_danswer
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import get_or_generate_uuid
from danswer.utils.telemetry import optional_telemetry
@@ -86,6 +94,7 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -140,6 +149,11 @@ def include_router_with_global_prefix_prepended(
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
logger.notice(f"System recursion limit set to {SYSTEM_RECURSION_LIMIT}")
SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME)
SqlEngine.init_engine(
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
@@ -150,6 +164,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
# Will throw exception if an issue is found
verify_auth()
@@ -162,11 +177,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# fill up Postgres connection pools
await warm_up_connections()
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
with Session(engine) as db_session:
setup_danswer(db_session)
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
setup_danswer(db_session)
else:
setup_multitenant_danswer()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -190,6 +209,15 @@ def get_application() -> FastAPI:
application = FastAPI(
title="Danswer Backend", version=__version__, lifespan=lifespan
)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
# Add the custom exception handler
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
@@ -219,6 +247,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
@@ -240,7 +269,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step
pass
elif AUTH_TYPE == AuthType.BASIC:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
@@ -272,7 +301,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
@@ -288,6 +317,7 @@ def get_application() -> FastAPI:
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
@@ -329,7 +359,7 @@ if __name__ == "__main__":
f"Starting Danswer Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/"
)
if global_version.get_is_ee_version():
if global_version.is_ee_version():
logger.notice("Running Enterprise Edition")
uvicorn.run(app, host=APP_HOST, port=APP_PORT)

View File

@@ -50,23 +50,26 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
_WHITELIST = set(
" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)
_INITIAL_FILTER = re.compile(
"["
"\U00000080-\U0000FFFF" # All Unicode characters beyond ASCII
"\U00010000-\U0010FFFF" # All Unicode characters in supplementary planes
"\U0000FFF0-\U0000FFFF" # Specials
"\U0001F000-\U0001F9FF" # Emoticons
"\U00002000-\U0000206F" # General Punctuation
"\U00002190-\U000021FF" # Arrows
"\U00002700-\U000027BF" # Dingbats
"]+",
flags=re.UNICODE,
)
def clean_openai_text(text: str) -> str:
# First, remove all weird characters
# Remove specific Unicode ranges that might cause issues
cleaned = _INITIAL_FILTER.sub("", text)
# Then, keep only whitelisted characters
return "".join(char for char in cleaned if char in _WHITELIST)
# Remove any control characters except for newline and tab
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")
return cleaned
def build_model_server_url(
@@ -97,6 +100,8 @@ class EmbeddingModel:
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -106,6 +111,8 @@ class EmbeddingModel:
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@@ -157,6 +164,8 @@ class EmbeddingModel:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
@@ -239,6 +248,8 @@ class EmbeddingModel:
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
)

View File

@@ -1,4 +1,7 @@
import functools
import threading
from collections.abc import Callable
from typing import Any
from typing import Optional
import redis
@@ -14,6 +17,72 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tenant_id: str = tenant_id
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
prefix: str = f"{self.tenant_id}:"
if isinstance(key, str):
if key.startswith(prefix):
return key
else:
return prefix + key
elif isinstance(key, bytes):
prefix_bytes = prefix.encode()
if key.startswith(prefix_bytes):
return key
else:
return prefix_bytes + key
elif isinstance(key, memoryview):
key_bytes = key.tobytes()
prefix_bytes = prefix.encode()
if key_bytes.startswith(prefix_bytes):
return key
else:
return memoryview(prefix_bytes + key_bytes)
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def _prefix_method(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"lock",
"unlock",
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
"scan_iter",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
class RedisPool:
@@ -32,8 +101,10 @@ class RedisPool:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
@staticmethod
def create_pool(
@@ -84,8 +155,8 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
# # Usage example

View File

@@ -102,6 +102,7 @@ class BaseFilters(BaseModel):
class IndexFilters(BaseFilters):
access_control_list: list[str] | None
tenant_id: str | None = None
class ChunkMetric(BaseModel):

View File

@@ -1,5 +1,6 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
@@ -9,6 +10,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.engine import current_tenant_id
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
@@ -160,6 +162,7 @@ def retrieval_preprocessing(
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

Some files were not shown because too many files have changed in this diff Show More