Compare commits


101 Commits

Author SHA1 Message Date
rkuo-danswer
aa187c86e2 Merge pull request #2726 from danswer-ai/bugfix/docker-web-runners
try porting docker web build to runs-on
2024-10-08 14:42:43 -07:00
Richard Kuo (Danswer)
c72c5619f0 remove more flaky tests 2024-10-08 14:42:04 -07:00
Chris Weaver
78e7710f17 Handle bug with initial connector page display (#2727)
* Handle bug with initial connector page display

* Casing consistency
2024-10-08 21:01:37 +00:00
rkuo-danswer
672f5cc5ce urlencode the password part properly before putting it in the broker url (#2719)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-08 20:46:11 +00:00
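A minimal sketch of the idea behind this fix (the values and variable names here are illustrative, not taken from the actual codebase): special characters in the Redis password must be percent-encoded before being embedded in the Celery broker URL, otherwise characters such as '@' or '/' break URL parsing.

from urllib.parse import quote

# Hypothetical values, for illustration only.
redis_password = "p@ss/word:with#specials"
redis_host = "cache"
redis_port = 6379

# Percent-encode the password (safe="" so '/' and ':' are escaped as well)
# before placing it in the broker URL.
encoded_password = quote(redis_password, safe="")
broker_url = f"redis://:{encoded_password}@{redis_host}:{redis_port}/0"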
rkuo-danswer
7b3c433ff8 Merge pull request #2717 from danswer-ai/bugfix/docker-legacy-key-value-format
Fix all LegacyKeyValueFormat docker warnings
2024-10-08 13:57:10 -07:00
Richard Kuo (Danswer)
057321a59f disable flaky test 2024-10-08 13:40:35 -07:00
Richard Kuo (Danswer)
5cc46341f7 try porting docker web build to runs-on 2024-10-08 13:11:59 -07:00
Chris Weaver
21a3921790 Better support for image generation capable models (#2725) 2024-10-08 12:41:14 -07:00
Chris Weaver
aa69fe762b Temp patch to remove multiple tool calls (#2720) 2024-10-08 18:08:45 +00:00
pablodanswer
3ef72b8d1a k (#2721) 2024-10-08 09:33:29 -07:00
pablodanswer
a0124e4e50 ensure all timeout -> hook (#2718) 2024-10-08 15:48:38 +00:00
Richard Kuo (Danswer)
a52485bda2 Fix all LegacyKeyValueFormat docker warnings 2024-10-07 15:22:28 -07:00
rkuo-danswer
79d37156c6 better logging for actions being taken inside document_by_cc_pair_cleanup (#2713) 2024-10-07 22:21:16 +00:00
rkuo-danswer
6fa8fabb47 add one more retry and wait a little longer to allow ourselves to recover from infra issues (#2714) 2024-10-07 22:17:49 +00:00
pablodanswer
4214a3a6e2 Inline code + effect clarity (#2715)
* cleaner code blocks + form context

* cleaner

* nit
2024-10-07 15:23:37 -07:00
rkuo-danswer
1a3469d2c5 check before using fetch_versioned_implementation because it logs warnings that confuse users. (#2708)
Renamed get_is_ee_version to is_ee_version to be less redundant
2024-10-07 21:37:56 +00:00
rkuo-danswer
30dc408028 rely on stdout redirection for supervisord logging (#2711) 2024-10-07 21:30:03 +00:00
Yuhong Sun
5d356cc971 Remove Perm Sync Script Dev (#2712) 2024-10-07 13:50:30 -07:00
pablodanswer
e4c7cfde42 Minor update to initial modal (#2571)
* minor update

* nit: pretty
2024-10-07 20:29:04 +00:00
pablodanswer
1900a390d8 Linting (#2704)
* effect cleanup

* remove unused imports

* remove unne

* remove unnecessary packages

* k

* temp

* minor
2024-10-07 20:21:07 +00:00
pablodanswer
150dcc2883 back button + popups (#2707)
* back button + popups

* remove logs
2024-10-07 20:10:58 +00:00
rkuo-danswer
3404c7eb1d Feature/background prune 2 (#2583)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* multiple celery workers

* update logs as well and set prefetch multipliers appropriate to the worker intent

* add db refresh to connector deletion

* add some preliminary locking

* organize tasks into separate files

* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this (see the short sketch after this commit entry).

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

* add multi workers to dev_run_background_jobs.py

* update supervisord with some recommended settings for celery

* name celery workers and shorten dev script prefixing

* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)

* fix comments

* autoscale sqlalchemy pool size to celery concurrency (allow override later?)

* supervisord needs the percent symbols escaped

* use name as primary check, some minor refactoring and type hinting too.

* stash merge (may not function yet)

* remove dead code

* more cleanup

* remove dead file

* we shouldn't be checking for deletion attempts in the db any more

* print cc_pair_id

* print status on status mismatch again

* add logging when cc_pair isn't present

* don't indexing any ingestion type connectors, and don't pause any connectors that aren't active

* add more specific check for deletion completion

* remove flaky mediawiki test site

* move is_pruning

* remove unused code

* remove old function

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-07 18:16:17 +00:00
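As a brief, hedged illustration of the trail=False bullet in the commit message above (a generic Celery sketch, not code from this PR): Celery tasks record the child tasks they spawn and attach them to the parent's result metadata by default; disabling trail avoids that bloat.

from celery import Celery

app = Celery("example", broker="redis://localhost:6379/0")

@app.task
def cleanup_task(doc_id: str) -> None:
    ...  # placeholder body

# trail=False stops Celery from recording the subtasks spawned inside this
# task in the parent's result metadata (result.children), keeping results small.
@app.task(trail=False)
def generator_task() -> None:
    for doc_id in ["doc-a", "doc-b", "doc-c"]:  # illustrative ids
        cleanup_task.delay(doc_id)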
pablodanswer
64909d74f9 UX Cleanup (#2701)
* start

* shared iconlogo class

* clean out of place components

* nit
2024-10-07 17:33:08 +00:00
Yuhong Sun
83bc7d4656 DanswerBot Update (#2697) 2024-10-06 14:27:31 -07:00
pablodanswer
3206bb27ce update disabling logic (#2592) 2024-10-06 20:31:19 +00:00
pablodanswer
f189eda904 remove left-over memo (#2669) 2024-10-06 19:05:28 +00:00
pablodanswer
7aaf822430 Enable removal of reranking + navigate back to search settings (#2674)
* k

* nit
2024-10-06 19:05:17 +00:00
Yuhong Sun
0ff5180d7b Ensure tests don't use LLM (#2702) 2024-10-06 11:42:49 -07:00
evan-danswer
089c734f63 disabled llm when skip_gen_ai_answer_question set (#2687)
* disabled llm when skip_gen_ai_answer_question set

* added unit test

* typing
2024-10-06 18:10:02 +00:00
pablodanswer
0da736bed9 Tenant provisioning in the dataplane (#2694)
* add tenant provisioning to data plane

* minor typing update

* ensure tenant router included

* proper auth check

* update disabling logic

* validated basic provisioning

* use new kv store
2024-10-06 04:08:35 +00:00
Chris Weaver
e00f4678df Add option to adjust pool size (#2695) 2024-10-05 23:37:48 +00:00
pablodanswer
e56fd43ba6 cors update (#2686) 2024-10-05 23:08:28 +00:00
pablodanswer
28e65669b4 add multi tenant alembic (#2589) 2024-10-05 21:59:15 +00:00
pablodanswer
493c3d7314 Add only multi tenant dependency injection (#2588)
* add only dependency injection

* quick typing fix

* additional non-dependency context

* update nits
2024-10-05 21:08:41 +00:00
pablodanswer
b04e9e9b67 Improved api key forms + fix non-submittable azure (#2654) 2024-10-04 19:29:45 -07:00
rkuo-danswer
3755e575a5 harden connections to redis (#2677)
* set broker_connection_retry_on_startup to silence deprecation warning (we're OK with retrying on startup)

* env var for CELERY_BROKER_POOL_LIMIT

* add redis retry on timeout and health check interval

* set socket_keepalive = True

* remove shadow declaration of REDIS_HEALTH_CHECK_INTERVAL, add socket_keepalive_options where possible

* fix mypy complaint

* pass through vars in docker compose

* remove extra '='

* wrap in a try
2024-10-04 16:00:48 +00:00
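A rough sketch of the kinds of settings this commit message describes (the values are assumptions; the real configuration is env-driven in the PR):

import redis
from celery import Celery

app = Celery("example", broker="redis://:password@cache:6379/0")

# Retry broker connections on startup instead of failing immediately;
# also silences the Celery deprecation warning mentioned above.
app.conf.broker_connection_retry_on_startup = True
# Cap the broker connection pool (CELERY_BROKER_POOL_LIMIT in the PR).
app.conf.broker_pool_limit = 10

# Equivalent hardening on a plain redis-py client.
r = redis.Redis(
    host="cache",
    port=6379,
    retry_on_timeout=True,
    health_check_interval=30,
    socket_keepalive=True,
)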
rkuo-danswer
63655cfbed update_single should be optimized for a single call now (#2671)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-04 15:43:04 +00:00
rkuo-danswer
7f788e4b1e bump celery to 5.5.0b4 (#2681) 2024-10-04 05:54:32 +00:00
Chris Weaver
1362d4b583 Allow config of background concurrency (#2648)
* Allow config of background concurrency

* Add comment

* Fix light worker

* use backslashes to continue lines in supervisord with bash

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@danswer.ai>
2024-10-04 00:55:28 +00:00
rkuo-danswer
4f47004d47 disable another flaky assert (#2678) 2024-10-04 00:25:46 +00:00
rkuo-danswer
3fdd233e84 delete directly via selection instead of making multiple calls to get chunk ids and delete each one (#2666) 2024-10-03 01:57:25 +00:00
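As a rough illustration of what "delete via selection" means (an assumption about the general shape of Vespa's document/v1 delete-where endpoint, not the code from #2666; the namespace, document type, field, and cluster names are placeholders):

import requests

VESPA_HOST = "http://localhost:8080"   # hypothetical
NAMESPACE = "default"                  # hypothetical
DOC_TYPE = "danswer_chunk"             # hypothetical
CLUSTER = "danswer_index"              # hypothetical

# One DELETE with a selection expression removes every matching chunk,
# instead of fetching chunk ids first and deleting them one by one.
resp = requests.delete(
    f"{VESPA_HOST}/document/v1/{NAMESPACE}/{DOC_TYPE}/docid",
    params={
        "selection": f'{DOC_TYPE}.document_id=="some-document-id"',
        "cluster": CLUSTER,
    },
)
resp.raise_for_status()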
Yuhong Sun
0c54d9d57d Unstructured Update Copy (#2668) 2024-10-02 17:48:11 -07:00
hagen-danswer
c2088602e1 Implement source testing framework + Slack (#2650)
* Added permission sync tests for Slack

* moved folders

* prune test + mypy

* added wait for indexing to cc_pair creation

* commented out check

* should fix other tests

* added slack channel pool

* fixed everything and mypy

* reduced flake
2024-10-02 23:16:07 +00:00
Chris Weaver
b3c367d09c [tiny] adjust user group sync log (#2664) 2024-10-02 18:01:40 +00:00
pablodanswer
457d32fef0 add clarity around assistants and names (#2663) 2024-10-02 18:00:06 +00:00
pablodanswer
af187c6cfe Better virtualization (#2653) 2024-10-02 11:14:59 -07:00
rkuo-danswer
a0235b7b7b replace trivy download endpoint due to db download flakiness on their en… (#2661)
* disable trivy for the moment due to db download flakiness on their end causing the action to fail

* try hardcoding to amazon registry as others have suggested
2024-10-02 17:13:19 +00:00
pablodanswer
a30de693cb Clean, memoized assistant ordering (#2655)
* updated refresh

* memoization and so on

* nit

* build issue
2024-10-02 16:15:54 +00:00
pablodanswer
07aeea69e7 Dupe welcome modal logic (#2656) 2024-10-01 20:11:39 -07:00
Evan Lohn
bd40328a73 fix typo 2024-10-01 20:10:37 -07:00
Chris Weaver
b8232e0681 Update litellm to fix bedrock models (#2649) 2024-10-01 20:09:57 -07:00
Yuhong Sun
fffb9c155a Redis Cache for KV Store (#2603)
* k

* k

* k

* k
2024-10-01 18:31:18 +00:00
rkuo-danswer
f513c5bbed sync up when checks run with branch protection required checks (#2628) 2024-10-01 17:59:10 +00:00
pablodanswer
9a4e51a18e add default model + minor fixes (#2638)
* add default model + minor fixes

* fix build

* minor additional fix

* build fix
2024-10-01 17:43:43 +00:00
rkuo-danswer
2f2fc08553 raise redis connections and using blocking connection pool for more d… (#2635)
* raise redis connections and using blocking connection pool for more deterministic behavior

* improve comment
2024-10-01 17:27:17 +00:00
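A minimal sketch of the blocking-pool idea (the parameter values are assumptions, not the ones used in the PR): a BlockingConnectionPool makes callers wait for a free connection instead of raising as soon as the pool is exhausted, which gives more deterministic behavior under load.

import redis

# Wait up to 30 seconds for a free connection rather than erroring out
# when all max_connections are already in use.
pool = redis.BlockingConnectionPool(
    host="cache",
    port=6379,
    max_connections=50,
    timeout=30,
)
r = redis.Redis(connection_pool=pool)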
pablodanswer
c68c6fdc44 welcome flow 2024-10-01 10:34:53 -07:00
hagen-danswer
834c76e30a Added quotes to project name to handle reserved words (#2639) 2024-10-01 10:32:41 -07:00
rkuo-danswer
ec02665ffa run the nightly tag overnight relative to pacific time (#2637) 2024-10-01 16:36:40 +00:00
pablodanswer
3fa1b18306 update nav link name (#2643)
* update nav link name

* underscore -> dash
2024-10-01 16:34:30 +00:00
Chris Weaver
c9bdf4c443 Update CONTRIBUTING.md 2024-10-01 08:46:25 -07:00
Yuhong Sun
e229d27734 Unstructured UI (#2636)
* checkpoint

* k

* k

* need frontend

* add api key check + ui component

* add proper ports + icons + functions

* k

* k

* k

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-10-01 04:50:03 +00:00
rkuo-danswer
140c5b3957 don't push integration testing docker images (#2584)
* experiment with build and no push

* use slightly more descriptive and consistent tags and names

* name integration test workflow consistently with other workflows

* put the tag back

* try runs-on s3 backend

* try adding runs-on cache

* add with key

* add a dummy path

* forget about multiline

* maybe we don't need runs-on cache immediately

* lower ram slightly, name test with a version bump

* don't need to explicitly include runs-on/cache for docker caching

* comment out flaky portion of knowledge chat test

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-01 01:00:47 +00:00
Chris Weaver
3e511497d2 Fix overflow of prompt library table (#2606) 2024-09-30 15:31:12 +00:00
hagen-danswer
b0056907fb Added permissions syncing for slack (#2602)
* Added permissions syncing for slack

* add no email case handling

* mypy fixes

* frontend

* minor cleanup

* param tweak
2024-09-30 15:14:43 +00:00
Chris Weaver
728a41a35a Add heartbeat to indexing (#2595) 2024-09-29 19:26:40 -07:00
Chris Weaver
ef8dda2d47 Rely on PVC (#2604) 2024-09-29 17:30:39 -07:00
pablodanswer
15283b3140 prevent nextFormStep unless credential fully set up (#2599) 2024-09-29 22:47:45 +00:00
Chris Weaver
e159b2e947 Fix default assistant (#2600)
* Fix default assistant

* Remove log

* Add newline
2024-09-29 22:47:14 +00:00
Jeff Knapp
9155800fab EKS initial deployment (#2154)
Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-09-29 15:51:31 -07:00
pablodanswer
a392ef0541 Show transition card if no connectors (#2597)
* show transition card if no connectors

* squash

* update apos
2024-09-29 22:35:41 +00:00
Yuhong Sun
5679f0af61 Minor Query History Fix (#2594) 2024-09-29 10:54:08 -07:00
rkuo-danswer
ff8db71cb5 don't write a nightly tag to the same commit more than once (#2585)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-29 10:36:08 -07:00
hagen-danswer
1cff2b82fd Global Curator Fix + Testing (#2591)
* Global Curator Fix

* test fix
2024-09-28 20:14:39 +00:00
Chris Weaver
50dd3c8beb Add size limit to jira tickets (#2586) 2024-09-28 12:49:13 -07:00
hagen-danswer
66a459234d Minor role display refactor (#2578) 2024-09-27 16:50:03 +00:00
rkuo-danswer
19e57474dc Feature/xenforo (#2497)
* Xenforo forum parser support

* clarify ssl cert reqs

* missed a file

* add isLoadState function, fix up xenforo for data driven connector approach

* fixing a new edge case to skip an unexpected parsed element

* change documentsource to xenforo

* make doc id unique and comment what's happening

* remove stray log line

* address code review

---------

Co-authored-by: sime2408 <simun.sunjic@gmail.com>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:36:05 +00:00
rkuo-danswer
f9638f2ea5 try user deploy key approach to tagging (#2575)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:04:55 +00:00
rkuo-danswer
fbf51b70d0 Feature/celery multi (#2470)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* multiple celery workers

* update logs as well and set prefetch multipliers appropriate to the worker intent

* add db refresh to connector deletion

* add some preliminary locking

* organize tasks into separate files

* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this.

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

* add multi workers to dev_run_background_jobs.py

* update supervisord with some recommended settings for celery

* name celery workers and shorten dev script prefixing

* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)

* fix comments

* autoscale sqlalchemy pool size to celery concurrency (allow override later?)

* supervisord needs the percent symbols escaped

* use name as primary check, some minor refactoring and type hinting too.

* addressing code review

* fix import

* fix prune_documents_task references

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 00:50:55 +00:00
hagen-danswer
b97cc01bb2 Added confluence permission syncing (#2537)
* Added confluence permission syncing

* separated out group and doc syncing

* minorbugfix and mypy

* added frontend and fixed bug

* Minor refactor

* dealt with confluence rate limits!

* mypy fixes!!!

* addressed yuhong feedback

* primary key fix
2024-09-26 22:10:41 +00:00
rkuo-danswer
6d48fd5d99 clamp retry to max_delay (#2570) 2024-09-26 21:56:46 +00:00
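The clamping referred to here follows, in general terms, the pattern below (an illustrative sketch, not the actual code from #2570):

# Exponential backoff clamped to an upper bound so a retry never waits
# longer than max_delay seconds.
def retry_delay(attempt: int, base_delay: float = 1.0, max_delay: float = 60.0) -> float:
    return min(base_delay * (2 ** attempt), max_delay)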
Chris Weaver
1f61447b4b Add open in new tab for custom links (#2568) 2024-09-26 20:01:35 +00:00
rkuo-danswer
deee2b3513 push to docker latest when git tag contains "latest", and tag nightly (#2564)
* comment docker tag latest

* make latest builds contingent on a "latest" keyword in the tag

* v4 checkout

* nightly tag push
2024-09-26 17:40:13 +00:00
hagen-danswer
b73d66c84a Cleaned up foreign key cleanup for user group deletion (#2559)
* cleaned up fk cleanup for user group deletion

* added test for user group deletion
2024-09-26 03:38:01 +00:00
rkuo-danswer
c5a61f4820 Feature/test pruning (#2556)
* add test to exercise pruning

* add prettierignore

* mypy fix

* mypy again

* try getting all the env vars set up correctly

* fix ports and hostnames
2024-09-25 23:34:13 +00:00
pablodanswer
ea4a3cbf86 update folder list (#2563) 2024-09-25 16:25:45 -07:00
rkuo-danswer
166514cedf ssl_ca_certs should default to None, not "". (#2560)
* ssl_ca_certs should default to None, not "".

otherwise, if ssl is enabled it will look for the cert on an empty path and fail.

* mypy fix
2024-09-25 19:56:21 +00:00
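A small sketch of the distinction this fix draws (hedged; the connection values are illustrative): with SSL enabled, an empty-string ssl_ca_certs is treated as a certificate path and fails, while None lets the client fall back to its default certificate handling.

import redis

ssl_ca_certs = None  # not "": an empty string would be treated as a
                     # (nonexistent) CA certificate path and fail when SSL is on

r = redis.Redis(
    host="cache",
    port=6380,
    ssl=True,
    ssl_ca_certs=ssl_ca_certs,
)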
pablodanswer
be50ae1e71 flex none (#2558) 2024-09-25 10:19:37 -07:00
pablodanswer
f89504ec53 Update some ux edge cases (#2545)
* update some ux edge cases

* update some formatting / ports
2024-09-25 16:46:43 +00:00
trial-danswer
6b3213b1e4 fix typo (#2543)
* fix typo

* Update EmbeddingFormPage.tsx

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-09-25 01:25:46 +00:00
Chris Weaver
48577bf0e4 Allow = in tag filter (#2548)
* Allow = in tag filter

* Rename func
2024-09-24 21:37:35 +00:00
pablodanswer
c59d1ff0a5 Update merge queue logic (#2554)
* update merge queue logic

* remove space
2024-09-24 18:45:05 +00:00
pablodanswer
ba38dec592 ensure default_assistant passed through 2024-09-24 11:35:19 -07:00
pablodanswer
f5adc3063e Update theming (#2552)
* update theming

* update

* update theming
2024-09-24 18:01:08 +00:00
hagen-danswer
8cfe80c53a Added doc_set__user_group cleanup for user_group deletion (#2551) 2024-09-24 16:09:52 +00:00
ThomaciousD
487250320b fix saml email login upsert issue 2024-09-24 07:42:08 -07:00
rkuo-danswer
c8d13922a9 rename classes and ignore deprecation warnings we mostly don't have c… (#2546)
* rename classes and ignore deprecation warnings we mostly don't have control over

* copy pytest.ini

* ignore CryptographyDeprecationWarning

* fully qualify the warning
2024-09-24 00:21:42 +00:00
rkuo-danswer
cb75449cec Feature/runs on 2 (#2547)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)

* try upping the RAM for future flake proofing
2024-09-23 23:46:20 +00:00
rkuo-danswer
b66514cd21 test self hosted runner (#2541)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)
2024-09-23 21:57:23 +00:00
Chris Weaver
77650c9ee3 Fix misc tool call errors (#2544)
* Fix misc tool call errors

* Fix middleware
2024-09-23 21:00:48 +00:00
pablodanswer
316b6b99ea Tooling testing (#2533)
* add initial testing

* add custom tool testing

* update ports

* update tests - additional coverage

* update types
2024-09-23 20:09:01 +00:00
Chris Weaver
34c2aa0860 Support svg navigation items (#2542)
* Support SVG nav items

* Handle specifying custom SVGs for navbar

* Add comment

* More comment

* More comment
2024-09-23 13:22:20 -07:00
437 changed files with 41298 additions and 4579 deletions

View File

@@ -32,16 +32,20 @@ inputs:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before retry in seconds'
description: 'Time to wait before attempt 2 in seconds'
required: false
default: '5'
default: '60'
retry-wait-time-2:
description: 'Time to wait before attempt 3 in seconds'
required: false
default: '120'
runs:
using: "composite"
steps:
- name: Build and push Docker image (First Attempt)
- name: Build and push Docker image (Attempt 1 of 3)
id: buildx1
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
continue-on-error: true
with:
context: ${{ inputs.context }}
@@ -54,16 +58,17 @@ runs:
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait to retry
- name: Wait before attempt 2
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Retry Attempt)
- name: Build and push Docker image (Attempt 2 of 3)
id: buildx2
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
@@ -74,3 +79,31 @@ runs:
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait before attempt 3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
run: |
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
sleep ${{ inputs.retry-wait-time-2 }}
shell: bash
- name: Build and push Docker image (Attempt 3 of 3)
id: buildx3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Report failure
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
run: |
echo "All attempts failed. Possible transient infrastructure issues? Try again later or inspect logs for details."
shell: bash

View File

@@ -7,16 +7,17 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,7 +32,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
@@ -41,12 +42,20 @@ jobs:
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.REGISTRY_IMAGE }}:latest
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -5,14 +5,18 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
runs-on:
group: amd64-image-builders
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,13 +35,21 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -7,11 +7,15 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
- run-id=${{ github.run_id }}
- tag=platform-${{ matrix.platform }}
strategy:
fail-fast: false
matrix:
@@ -35,7 +39,7 @@ jobs:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -112,8 +116,16 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,3 +1,6 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".
name: Tag Latest Version
on:
@@ -9,7 +12,9 @@ on:
jobs:
tag:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1

View File

@@ -1,19 +1,23 @@
name: Run Integration Tests
name: Run Integration Tests v2
concurrency:
group: Run-Integration-Tests-${{ github.head_ref }}
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
integration-tests:
runs-on: Amd64
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -27,25 +31,35 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: we don't need to build the Web Docker image since it's not used
# during the IT for now. We have a separate action to verify it builds
# succesfully
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
type=registry,ref=danswer/danswer-backend:it,mode=max
type=inline
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
@@ -53,11 +67,11 @@ jobs:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
type=registry,ref=danswer/danswer-model-server:it,mode=max
type=inline
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
@@ -65,11 +79,11 @@ jobs:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |
type=registry,ref=danswer/integration-test-runner:it,mode=max
type=inline
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
@@ -78,7 +92,7 @@ jobs:
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=it \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
@@ -120,6 +134,7 @@ jobs:
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
@@ -128,7 +143,9 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
danswer/integration-test-runner:it
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
continue-on-error: true
id: run_tests

View File

@@ -12,7 +12,8 @@ on:
jobs:
lint-test:
runs-on: Amd64
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:

View File

@@ -3,11 +3,14 @@ name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
mypy-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code

View File

@@ -15,10 +15,14 @@ env:
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend

View File

@@ -0,0 +1,58 @@
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -3,11 +3,14 @@ name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
backend-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend

View File

@@ -1,6 +1,6 @@
name: Quality Checks PR
concurrency:
group: Quality-Checks-PR-${{ github.head_ref }}
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
@@ -9,7 +9,8 @@ on:
jobs:
quality-checks:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- uses: actions/checkout@v4
with:

.github/workflows/tag-nightly.yml (new file, 54 lines)
View File

@@ -0,0 +1,54 @@
name: Nightly Tag Push
on:
schedule:
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Check for existing nightly tag
id: check_tag
run: |
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
echo "A tag starting with 'nightly-latest' already exists on HEAD."
echo "tag_exists=true" >> $GITHUB_OUTPUT
else
echo "No tag starting with 'nightly-latest' exists on HEAD."
echo "tag_exists=false" >> $GITHUB_OUTPUT
fi
# don't tag again if HEAD already has a nightly-latest tag on it
- name: Create Nightly Tag
if: steps.check_tag.outputs.tag_exists == 'false'
env:
DATE: ${{ github.run_id }}
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
echo "Creating tag: $TAG_NAME"
git tag $TAG_NAME
- name: Push Tag
if: steps.check_tag.outputs.tag_exists == 'false'
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME

.prettierignore (new file, 1 line)
View File

@@ -0,0 +1 @@
backend/tests/integration/tests/pruning/website

View File

@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.

View File

@@ -101,7 +101,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
# Default command which does nothing
# This container is used by api server and background which specify their own CMD

View File

@@ -55,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -9,9 +9,9 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
@@ -21,16 +21,26 @@ if config.config_file_name is not None and config.attributes.get(
):
fileConfig(config.config_file_name)
# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_schema_options() -> tuple[str, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
@@ -54,17 +64,20 @@ def run_migrations_offline() -> None:
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
schema, _ = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
@@ -72,22 +85,28 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -101,7 +120,6 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())

View File

@@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)

View File

@@ -9,7 +9,7 @@ import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,9 +54,7 @@ def upgrade() -> None:
)
try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
@@ -71,7 +69,7 @@ def upgrade() -> None:
)
# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")
get_kv_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found

View File

@@ -0,0 +1,27 @@
"""add last_pruned to the connector_credential_pair table
Revision ID: ac5eaac849f9
Revises: 52a219fb5233
Create Date: 2024-09-10 15:04:26.437118
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last pruned represents the last time the connector was pruned
op.add_column(
"connector_credential_pair",
sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_pruned")

View File

@@ -1,20 +1,20 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except ConfigNotFoundError:
except KvKeyNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -4,29 +4,29 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.key_value_store.store import KeyValueStore
from danswer.key_value_store.store import KvKeyNotFoundError
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
store: KeyValueStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",

View File

@@ -8,6 +8,7 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
@@ -37,8 +38,10 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
@@ -342,7 +345,6 @@ def get_database_strategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
@@ -505,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")

File diff suppressed because it is too large.

View File

@@ -21,6 +21,7 @@ from danswer.db.document import (
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
@@ -172,6 +173,9 @@ class RedisUserGroup(RedisObjectHelper):
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
@@ -343,6 +347,125 @@ class RedisConnectorDeletion(RedisObjectHelper):
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=self._id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
if redis_client.exists(self.fence_key):
return True
return False
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists

View File

@@ -1,11 +1,11 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -16,20 +16,14 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_db_current_time
from danswer.db.enums import TaskStatus
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import RedisPool
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
redis_pool = RedisPool()
def _get_deletion_status(
@@ -46,7 +40,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = redis_pool.get_client()
r = get_redis_client()
if not r.exists(rcd.fence_key):
return None
@@ -69,53 +63,19 @@ def get_deletion_attempt_snapshot(
)
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
"""
If the PruneConnector hasn't been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
@@ -138,6 +98,36 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
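The optional progress_callback is just a hook that receives each batch's size; a minimal caller sketch is below. The callback name and the hard-coded key are illustrative only, but the Redis counter wiring mirrors the redis_increment_callback defined later in this PR:

from danswer.redis.redis_pool import get_redis_client

def report_progress(batch_size: int) -> None:
    # hypothetical callback: bump a Redis counter so a monitor task can report progress
    get_redis_client().incrby("connectorpruning_generator_progress_1", batch_size)

# runnable_connector is assumed to already be instantiated by the caller
doc_ids = extract_ids_from_runnable_connector(runnable_connector, report_progress)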
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
"""Checks to see if we're listening to the named queue"""
# how to get a list of queues this worker is listening to
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
queue_names = list(worker.app.amqp.queues.consume_from.keys())
for queue_name in queue_names:
if queue_name == name:
return True
return False
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken to determine if a celery worker
is 'primary', as defined by us. But the way we do it is to check the hostname set
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("heavy"):
return False
return True

View File

@@ -1,7 +1,11 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
import urllib.parse
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
@@ -9,12 +13,13 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@"
REDIS_SCHEME = "redis"
@@ -36,12 +41,30 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
"retry_on_timeout": True,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
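priority_steps tells the kombu Redis transport how many priority levels to emulate with separate lists. Assuming DanswerCeleryPriority is an enum with five members, the setting works out roughly as follows (a sketch, not project code):

# hypothetical illustration, assuming five priority levels
priority_steps = list(range(5))  # [0, 1, 2, 3, 4]
# kombu then consumes from one Redis list per step,
# e.g. "celery", "celery:1", ..., using CELERY_SEPARATOR as the suffix separator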
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True

View File

@@ -0,0 +1,110 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated
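The fence/taskset handshake used here (and by the other generators in this PR) boils down to: the generator fills a Redis set with the subtask ids it enqueues and then writes the total count to the fence key; the monitor beat task watches the set's cardinality and tears both keys down once it drains to zero. A condensed sketch of both sides, not the actual code:

# generator side (condensed): fill the taskset, then raise the fence with the count
r.delete(rcd.taskset_key)
for custom_task_id in enqueue_cleanup_tasks():  # hypothetical helper yielding task ids
    r.sadd(rcd.taskset_key, custom_task_id)
r.set(rcd.fence_key, tasks_generated)

# monitor side (condensed): clean up once every subtask has checked itself off
if r.exists(rcd.fence_key) and r.scard(rcd.taskset_key) == 0:
    r.delete(rcd.taskset_key)
    r.delete(rcd.fence_key)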

View File

@@ -0,0 +1,137 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
@shared_task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
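pg_try_advisory_lock is a non-blocking, session-scoped Postgres lock: if another worker already holds the id, the call returns false and this task simply exits; the lock is released automatically when the session's connection is closed. A minimal standalone sketch of the same guard pattern (the helper name is hypothetical):

from sqlalchemy import text
from sqlalchemy.orm import Session

def run_exclusive(db_session: Session, lock_id: int) -> bool:
    """Do the work only if we win the advisory lock; otherwise bail out."""
    got_lock = db_session.execute(
        text("SELECT pg_try_advisory_lock(:id)"), {"id": lock_id}
    ).scalar()
    if not got_lock:
        return False
    # ... exclusive work goes here ...
    return True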
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True

View File

@@ -0,0 +1,239 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_prune_task_2",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task_2() -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, r, lock_beat
)
if not tasks_created:
continue
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def ccpair_pruning_generator_task_creation_helper(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return None
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return None
return try_creating_prune_generator_task(cc_pair, db_session, r)
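The schedule check is plain timestamp arithmetic: the next prune is due prune_freq seconds after the last prune (or after connector creation if the pair has never been pruned). A worked example with made-up values:

from datetime import datetime, timedelta, timezone

prune_freq = 86400                                        # hypothetical: prune once a day
last_pruned = datetime(2024, 10, 1, tzinfo=timezone.utc)  # hypothetical last prune time
next_prune = last_pruned + timedelta(seconds=prune_freq)  # 2024-10-02 00:00:00 UTC
if datetime.now(timezone.utc) < next_prune:
    pass  # too early, skip this cc_pair for now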
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
return 1
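Because try_creating_prune_generator_task skips the scheduling checks, it doubles as the entry point for prunes forced from the API. A hedged sketch of what such a caller could look like; the endpoint itself is not part of this diff, and the sketch assumes get_connector_credential_pair_from_id is imported alongside the helpers already used in this module:

def force_prune(cc_pair_id: int) -> bool:
    """Hypothetical API-side helper that kicks off an immediate prune."""
    r = get_redis_client()
    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
        if not cc_pair:
            return False
        return try_creating_prune_generator_task(cc_pair, db_session, r) is not None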
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, redis_increment_callback
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e

View File

@@ -0,0 +1,123 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where this
is the only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
try:
with Session(get_sqlalchemy_engine()) as db_session:
action = "skip"
chunks_affected = 0
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
action = "delete"
chunks_affected = document_index.delete_single(document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
action = "update"
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = document_index.update_single(
document_id, fields=fields
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
task_logger.info(
f"document_id={document_id} refcount={count} action={action} chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
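The retry countdown is a small exponential backoff keyed off Celery's retry counter; with max_retries=3 the successive delays work out as follows:

# retries = 0, 1, 2  ->  countdown = 2 ** (retries + 4)
delays = [2 ** (retries + 4) for retries in range(3)]
# delays == [16, 32, 64]  (seconds)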

View File

@@ -0,0 +1,580 @@
import traceback
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# This always raises on the MIT version, which is expected.
# We shouldn't actually get here if the ee version check works.
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(
f"Stale document sync progress: remaining={count} initial={initial_count}"
)
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
)
return
try:
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
mark_ccpair_as_pruned(cc_pair_id, db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task's lock timeout is CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
with Session(get_sqlalchemy_engine()) as db_session:
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
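monitor_vespa_sync, like the other check_for_* tasks in this PR, is meant to be driven by celery beat; the schedule itself lives in the celery app configuration rather than in this file. A sketch of what such a beat entry generally looks like; the name and interval are illustrative, not taken from this PR:

# illustrative beat schedule entry, not the project's actual configuration
beat_schedule = {
    "monitor-vespa-sync": {
        "task": "monitor_vespa_sync",
        "schedule": 15.0,  # run every 15 seconds (made-up interval)
    },
}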
@shared_task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -1,110 +0,0 @@
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where this
is the only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_counts
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
) -> None:
"""
Removes a batch of document IDs from a cc-pair. If no other cc-pair uses a document anymore,
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_counts = get_document_connector_counts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id
for document_id, cnt in document_connector_counts
if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_counts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()

View File

@@ -14,6 +14,7 @@ from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
@@ -29,6 +30,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -48,7 +50,7 @@ def _get_connector_runner(
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
Returns an interator of document batches and whether the returned documents
Returns an iterator of document batches and whether the returned documents
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
@@ -66,12 +68,17 @@ def _get_connector_runner(
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
raise e
return ConnectorRunner(
@@ -103,15 +110,24 @@ def _run_indexing(
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)

View File

@@ -23,7 +23,7 @@ from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -96,14 +96,20 @@ def _should_create_new_indexing(
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if not cc_pair.status.is_active() or connector.id == 0:
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
if not last_index:
@@ -347,7 +353,7 @@ def kickoff_indexing_jobs(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
global_version.is_ee_version(),
pure=False,
)
if not run:
@@ -358,7 +364,7 @@ def kickoff_indexing_jobs(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
global_version.is_ee_version(),
pure=False,
)
if not run:
@@ -416,6 +422,7 @@ def update_loop(
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -444,6 +451,7 @@ def update_loop(
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
@@ -475,7 +483,9 @@ def update_loop(
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
# initialize the Postgres connection pool
SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
logger.notice("Starting indexing service")
update_loop()

View File

@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
@@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_prompts_from_yaml(
db_session: Session, prompts_yaml: str = PROMPTS_YAML
) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
@@ -49,117 +50,117 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)

View File

@@ -138,6 +138,12 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
)
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -164,13 +170,29 @@ REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# will propagate to both our redis client and celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting this to None may help when a proxy in the path closes idle connections
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
CELERY_BROKER_POOL_LIMIT = int(
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
)
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
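A minimal sketch of how this parsing behaves, assuming only the standard library (the helper name below is hypothetical): any non-integer value, e.g. the string "None", falls back to the default of 10.
import os
def _parse_broker_pool_limit(default: int = 10) -> int:
    # Hypothetical helper mirroring the env parsing above.
    try:
        return int(os.environ.get("CELERY_BROKER_POOL_LIMIT", default))
    except ValueError:
        # e.g. CELERY_BROKER_POOL_LIMIT="None" is not an int, so fall back
        return default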
#####
# Connector Configs
#####
@@ -247,6 +269,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -270,7 +296,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
@@ -334,12 +360,10 @@ INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
#####
# File based Key Value store no longer used
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
@@ -388,3 +412,11 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")

View File

@@ -1,3 +1,5 @@
import platform
import socket
from enum import auto
from enum import Enum
@@ -34,9 +36,12 @@ POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -46,6 +51,7 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
@@ -62,6 +68,7 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
class DocumentSource(str, Enum):
@@ -104,6 +111,7 @@ class DocumentSource(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
@@ -179,17 +187,17 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
class DanswerCeleryPriority(int, Enum):
@@ -198,3 +206,13 @@ class DanswerCeleryPriority(int, Enum):
MEDIUM = auto()
LOW = auto()
LOWEST = auto()
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
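The diff does not show where these options are consumed; as a hedged sketch, they would typically be handed to a redis-py client through its socket keepalive parameters (host and port below are placeholders).
import redis
# Hypothetical client construction showing how the options above could be wired in.
r = redis.Redis(
    host="localhost",
    port=6379,
    socket_keepalive=True,
    socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)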

View File

@@ -194,8 +194,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
try:
text = extract_file_text(
name,
BytesIO(downloaded_file),
file_name=name,
break_on_unprocessable=False,
)
batch.append(

View File

@@ -0,0 +1,32 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
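A small self-contained illustration of the parsing above; the HTML snippet is made up, but Confluence storage format references attachments with <ri:attachment> tags.
page_html = (
    '<p>See <ri:attachment ri:filename="report.pdf"></ri:attachment> and '
    '<ri:attachment ri:filename="diagram.png"></ri:attachment></p>'
)
# Expected to yield ["report.pdf", "diagram.png"]
print(get_used_attachments(page_html))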

View File

@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.confluence_utils import get_used_attachments
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filename currently in used
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -533,7 +519,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return None
extracted_text = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
@@ -624,19 +612,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
# The url and the id are the same
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
)
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
unused_attachments.extend(unused_page_attachments)
page_text += attachment_text
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
@@ -683,8 +674,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter and not time_filter(last_updated):
continue
attachment_url = self._attachment_to_download_link(
self.confluence_client, attachment
# The url and the id are the same
attachment_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"]
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment

View File

@@ -50,6 +50,12 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)

View File

@@ -9,6 +9,7 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{description}\n" + "\n".join(
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
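To make the size guard above concrete, a small illustration (values are examples): the limit is applied to the UTF-8 encoded length, so multibyte characters count as more than one byte each.
JIRA_CONNECTOR_MAX_TICKET_SIZE = 100 * 1024  # 102400 bytes by default
# A long description with accented characters easily exceeds the limit.
ticket_content = "Résumé of the incident..." * 5000
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
    print("ticket skipped: too large")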
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {self.jira_project}",
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
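For clarity, an example of why the quoting matters (the project key is hypothetical): an unquoted project whose key matches a JQL reserved word would produce an invalid query.
# e.g. a project literally keyed "ORDER" collides with JQL's ORDER BY keyword
jira_project = "ORDER"
quoted_project = f'"{jira_project}"'
jql = f"project = {quoted_project} AND updated >= '2024-10-01 00:00'"
# -> project = "ORDER" AND updated >= '2024-10-01 00:00'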

View File

@@ -97,8 +97,8 @@ class DropboxConnector(LoadConnector, PollConnector):
link = self._get_shared_link(entry.path_display)
try:
text = extract_file_text(
entry.name,
BytesIO(downloaded_file),
file_name=entry.name,
break_on_unprocessable=False,
)
batch.append(

View File

@@ -42,6 +42,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.xenforo.connector import XenforoConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
@@ -62,6 +63,7 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -97,6 +99,7 @@ def identify_connector_class(
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -74,13 +74,14 @@ def _process_file(
)
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf":
elif extension == ".pdf" and pdf_pass is not None:
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
else:
file_content_raw = extract_file_text(
file_name=file_name,
file=file,
file_name=file_name,
break_on_unprocessable=True,
)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata

View File

@@ -25,7 +25,7 @@ from danswer.connectors.gmail.constants import (
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
@@ -72,7 +72,7 @@ def get_gmail_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
@@ -80,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -92,14 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -111,7 +111,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -158,42 +158,40 @@ def build_service_account_creds(
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
)
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_gmail_cred() -> None:
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
get_kv_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -36,6 +36,8 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -327,16 +329,24 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
elif mime_type == GDriveMimeType.WORD_DOC.value:
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
response = service.files().get_media(fileId=file["id"]).execute()
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT

View File

@@ -28,7 +28,7 @@ from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
@@ -134,7 +134,7 @@ def get_google_drive_creds(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
@@ -142,7 +142,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -154,7 +154,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -202,32 +202,28 @@ def build_service_account_creds(
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_cred() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -40,8 +40,8 @@ def _convert_driveitem_to_document(
driveitem: DriveItem,
) -> Document:
file_text = extract_file_text(
file_name=driveitem.name,
file=io.BytesIO(driveitem.get_content().execute_query().value),
file_name=driveitem.name,
break_on_unprocessable=False,
)

View File

@@ -8,13 +8,12 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
basic_retry_wrapper = retry_builder()
def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
def _get_channels(
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])
@@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# first, try fetching private channels as well
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels
def get_channel_messages(
@@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -217,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
"group_leave",
"group_archive",
"group_unarchive",
"channel_leave",
"channel_name",
"channel_join",
}
def _default_msg_filter(message: MessageType) -> bool:
def default_msg_filter(message: MessageType) -> bool:
# Don't keep messages from bots
if message.get("bot_id") or message.get("app_id"):
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
return False
return True
# Uninformative
@@ -266,14 +259,14 @@ def filter_channels(
]
def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
@@ -328,7 +321,44 @@ def get_all_docs(
)
class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel
This is nearly identical to _get_all_docs, but it returns a set of ids instead of documents
This makes it an order of magnitude faster than _get_all_docs
"""
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
all_doc_ids = set()
for channel in filtered_channels:
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we don't have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
return all_doc_ids
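A short sketch of the id format produced above (channel id and timestamp are made-up values):
# Ids combine the channel id and the ts of the first message in the thread.
channel = {"id": "C0123456789"}
message = {"ts": "1728419083.000200"}
doc_id = f"{channel['id']}__{message['ts']}"
# -> "C0123456789__1728419083.000200"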
class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
@@ -349,6 +379,16 @@ class SlackPollConnector(PollConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_source_ids(self) -> set[str]:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
@@ -356,7 +396,7 @@ class SlackPollConnector(PollConnector):
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,

View File

@@ -10,11 +10,13 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
@@ -34,7 +36,7 @@ def get_message_link(
)
def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
return logged_call
def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)
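As a hedged usage sketch (the token is a placeholder), the two wrappers above compose logging, rate limiting, retries, and, in the paginated variant, cursor handling:
from slack_sdk import WebClient
client = WebClient(token="xoxb-placeholder")
channels = []
# Each yielded page is a dict as returned by the Slack API.
for page in make_paginated_slack_api_call_w_retries(
    client.conversations_list, types=["public_channel"]
):
    channels.extend(page["channels"])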
def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,

View File

@@ -0,0 +1,244 @@
"""
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
forum, followed by the board name. For example:
base_url = 'https://www.example.com/forum/boards/some-topic/'
The `load_from_state` method is used to load documents from the forum. It takes an optional `state` parameter, which
can be used to specify a state from which to start loading documents.
"""
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
import pytz
import requests
from bs4 import BeautifulSoup
from bs4 import Tag
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_title(soup: BeautifulSoup) -> str:
el = soup.find("h1", "p-title-value")
if not el:
return ""
title = el.text
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
title = title.replace(char, "_")
return title
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
page_tags = soup.select("li.pageNav-page")
page_numbers = []
for button in page_tags:
if re.match(r"^\d+$", button.text):
page_numbers.append(button.text)
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
all_pages = []
for x in range(1, int(max_pages) + 1):
all_pages.append(f"{url}page-{x}")
return all_pages
def parse_post_date(post_element: BeautifulSoup) -> datetime:
el = post_element.find("time")
if not isinstance(el, Tag) or "datetime" not in el.attrs:
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
date_value = el["datetime"]
# Ensure date_value is a string (if it's a list, take the first element)
if isinstance(date_value, list):
date_value = date_value[0]
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
return datetime_to_utc(post_date)
def scrape_page_posts(
soup: BeautifulSoup,
page_index: int,
url: str,
initial_run: bool,
start_time: datetime,
) -> list:
title = get_title(soup)
documents = []
for post in soup.find_all("div", class_="message-inner"):
post_date = parse_post_date(post)
if initial_run or post_date > start_time:
el = post.find("div", class_="bbWrapper")
if not el:
continue
post_text = el.get_text(strip=True) + "\n"
author_tag = post.find("a", class_="username")
if author_tag is None:
author_tag = post.find("span", class_="username")
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
# TODO: if a caller calls this for each page of a thread, it may see the
# same post multiple times if there is a sticky post
# that appears on each page of a thread.
# it's important to generate unique doc ids, so the page index is part of the
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[Section(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,
primary_owners=[BasicExpertInfo(display_name=author)],
metadata={
"type": "post",
"author": author,
"time": formatted_time,
},
doc_updated_at=post_date,
)
documents.append(document)
return documents
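A quick illustration of the id format built above (all values are examples; the "xenforo" prefix mirrors DocumentSource.XENFORO.value), showing why the page index is part of the id:
# Two otherwise identical posts on different pages still get distinct ids.
title, formatted_time = "How_to_change_your_state", "2024-10-08 12:00:00"
id_page_1 = f"xenforo_{title}_1_{formatted_time}"
id_page_2 = f"xenforo_{title}_2_{formatted_time}"
assert id_page_1 != id_page_2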
class XenforoConnector(LoadConnector):
# Class variable to track if the connector has been run before
has_been_run_before = False
def __init__(self, base_url: str) -> None:
self.base_url = base_url
self.initial_run = not XenforoConnector.has_been_run_before
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
self.cookies: dict[str, str] = {}
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
self.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0.0.0 Safari/537.36"
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Xenforo Connector")
return None
def load_from_state(self) -> GenerateDocumentsOutput:
# Standardize URL to always end in /.
if self.base_url[-1] != "/":
self.base_url += "/"
# Remove all extra parameters from the end such as page, post.
matches = ("threads/", "boards/", "forums/")
for each in matches:
if each in self.base_url:
try:
self.base_url = self.base_url[
0 : self.base_url.index(
"/", self.base_url.index(each) + len(each)
)
+ 1
]
except ValueError:
pass
doc_batch: list[Document] = []
all_threads = []
# If the URL contains "boards/" or "forums/", find all threads.
if "boards/" in self.base_url or "forums/" in self.base_url:
pages = get_pages(self.requestsite(self.base_url), self.base_url)
# Get all pages on thread_list_page
for pre_count, thread_list_page in enumerate(pages, start=1):
logger.info(
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
)
all_threads += self.get_threads(thread_list_page)
# If the URL contains "threads/", add the thread to the list.
elif "threads/" in self.base_url:
all_threads.append(self.base_url)
# Process all threads
for thread_count, thread_url in enumerate(all_threads, start=1):
soup = self.requestsite(thread_url)
if soup is None:
logger.error(f"Failed to load page: {self.base_url}")
continue
pages = get_pages(soup, thread_url)
# Getting all pages for all threads
for page_index, page in enumerate(pages, start=1):
logger.info(
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
)
soup_page = self.requestsite(page)
doc_batch.extend(
scrape_page_posts(
soup_page, page_index, thread_url, self.initial_run, self.start
)
)
if doc_batch:
yield doc_batch
# Mark the initial run finished after all threads and pages have been processed
XenforoConnector.has_been_run_before = True
def get_threads(self, url: str) -> list[str]:
soup = self.requestsite(url)
thread_tags = soup.find_all(class_="structItem-title")
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
threads = []
for x in thread_tags:
y = x.find_all(href=True)
for element in y:
link = element["href"]
if "threads/" in link:
stripped = link[0 : link.rfind("/") + 1]
if base_url + stripped not in threads:
threads.append(base_url + stripped)
return threads
def requestsite(self, url: str) -> BeautifulSoup:
try:
response = requests.get(
url, cookies=self.cookies, headers=self.headers, timeout=10
)
if response.status_code != 200:
logger.error(
f"<{url}> Request Error: {response.status_code} - {response.reason}"
)
return BeautifulSoup(response.text, "html.parser")
except TimeoutError:
logger.error("Timed out Error.")
except Exception as e:
logger.error(f"Error on {url}")
logger.exception(e)
return BeautifulSoup("", "html.parser")
if __name__ == "__main__":
connector = XenforoConnector(
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -160,7 +160,7 @@ def handle_regular_answer(
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
elif new_message_request.persona_id is not None:
persona = cast(
Persona,
fetch_persona_by_id(

View File

@@ -49,7 +49,7 @@ from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
@@ -131,9 +131,8 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
)
return False
bot_tag_id = get_danswer_bot_app_id(client.web_client)
if event_type == "message":
bot_tag_id = get_danswer_bot_app_id(client.web_client)
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and bot_tag_id in msg
is_danswer_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
@@ -159,8 +158,10 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
if not slack_bot_config or not slack_bot_config.channel_config.get(
"respond_to_bots"
# If DanswerBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_bot_config
or not slack_bot_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info("Ignoring message from bot")
return False
@@ -447,8 +448,9 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception:
logger.exception("Failed to process slack event")
except Exception as e:
logger.exception(f"Failed to process slack event. Error: {e}")
logger.error(f"Slack request payload: {req.payload}")
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient:
@@ -522,7 +524,7 @@ if __name__ == "__main__":
# Let the handlers run in the background + re-check for token updates every 60 seconds
Event().wait(timeout=60)
except ConfigNotFoundError:
except KvKeyNotFoundError:
# try again every 30 seconds. This is needed since the user may add tokens
# via the UI at any point in the programs lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens

View File

@@ -2,7 +2,7 @@ import os
from typing import cast
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import SlackBotTokens
@@ -13,7 +13,7 @@ def fetch_tokens() -> SlackBotTokens:
if app_token and bot_token:
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store = get_kv_store()
return SlackBotTokens(
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
)
@@ -22,7 +22,7 @@ def fetch_tokens() -> SlackBotTokens:
def save_tokens(
tokens: SlackBotTokens,
) -> None:
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store = get_kv_store()
dynamic_config_store.store(
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
)

View File

@@ -430,35 +430,58 @@ def read_slack_thread(
replies = cast(dict, response.data).get("messages", [])
for reply in replies:
if "user" in reply and "bot_id" not in reply:
message = remove_danswer_bot_tag(reply["text"], client=client)
user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
message = reply["text"]
user_sem_id = (
fetch_user_semantic_id_from_id(reply.get("user"), client)
or "Unknown User"
)
message_type = MessageType.USER
else:
self_app_id = get_danswer_bot_app_id(client)
# Only include bot messages from Danswer, other bots are not taken in as context
if self_app_id != reply.get("user"):
continue
if reply.get("user") == self_app_id:
# DanswerBot response
message_type = MessageType.ASSISTANT
user_sem_id = "Assistant"
blocks = reply["blocks"]
if len(blocks) <= 1:
continue
# For the old flow, the useful block is the second one after the header block that says AI Answer
if reply["blocks"][0]["text"]["text"] == "AI Answer":
message = reply["blocks"][1]["text"]["text"]
else:
# for the new flow, the answer is the first block
message = reply["blocks"][0]["text"]["text"]
if message.startswith("_Filters"):
if len(blocks) <= 2:
# DanswerBot responses have both text and blocks
# The useful content is in the blocks, specifically the first block unless there are
# auto-detected filters
blocks = reply.get("blocks")
if not blocks:
logger.warning(f"DanswerBot response has no blocks: {reply}")
continue
message = reply["blocks"][2]["text"]["text"]
user_sem_id = "Assistant"
message_type = MessageType.ASSISTANT
message = blocks[0].get("text", {}).get("text")
# If auto-detected filters are on, use the second block for the actual answer
# The first block is the auto-detected filters
if message.startswith("_Filters"):
if len(blocks) < 2:
logger.warning(f"Only filter blocks found: {reply}")
continue
# This is the DanswerBot answer format; if there is a change to how we respond,
# this will need to be updated to get the correct "answer" portion
message = reply["blocks"][1].get("text", {}).get("text")
else:
# Other bots are not counted as the LLM response which only comes from Danswer
message_type = MessageType.USER
bot_user_name = fetch_user_semantic_id_from_id(
reply.get("user"), client
)
user_sem_id = (bot_user_name or "Unknown") + " Bot"
# For other bots, just use the text as we have no way of knowing what the
# useful portion is
message = reply.get("text")
if not message:
message = blocks[0].get("text", {}).get("text")
if not message:
logger.warning("Skipping Slack thread message, no text found")
continue
message = remove_danswer_bot_tag(message, client=client)
thread_messages.append(
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
)

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from typing import cast
from sqlalchemy import and_
@@ -268,3 +270,15 @@ def create_initial_default_connector(db_session: Session) -> None:
)
db_session.add(connector)
db_session.commit()
def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()

View File

@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()

View File

@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
return list(db_session.execute(doc_ids_stmt).scalars().all())
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
def get_documents_by_ids(
document_ids: list[str],
db_session: Session,
document_ids: list[str],
) -> list[DbDocument]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
documents = db_session.execute(stmt).scalars().all()

View File

@@ -1,10 +1,18 @@
import contextlib
import contextvars
import re
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import ContextManager
import jwt
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
@@ -17,6 +25,9 @@ from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
@@ -24,27 +35,24 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
@@ -108,6 +116,78 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None
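A couple of examples for the guard above; because schema names get interpolated into SQL, only this restricted character set is accepted:
assert is_valid_schema_name("tenant_42")            # letters, digits, _ and - are allowed
assert not is_valid_schema_name('public"; DROP')    # quotes, spaces and ; are rejected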
class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
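For clarity, here is a minimal usage sketch of the engine manager above, assuming a standalone worker process; the app name and pool numbers are illustrative placeholders, not values required by the class.

from sqlalchemy import text

from danswer.db.engine import SqlEngine

# Illustrative startup sequence for a hypothetical worker process.
SqlEngine.set_app_name("danswer_example_worker")    # shows up in pg_stat_activity
SqlEngine.init_engine(pool_size=8, max_overflow=0)  # first call wins; later calls are no-ops
engine = SqlEngine.get_engine()                     # would lazily init with defaults if init was skipped

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))                  # exercises the shared, process-wide pool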
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
@@ -120,69 +200,141 @@ def build_connection_string(
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _SYNC_ENGINE
return SqlEngine.get_engine()
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
pool_size=5,
max_overflow=0,
# async engine is only used by API server, so we can use those values
# here as well
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
# Context variable to store the current tenant ID
# This allows us to maintain tenant-specific context throughout the request lifecycle
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
# This context variable works in both synchronous and asynchronous contexts
# In async code, it's automatically carried across coroutines
# In sync code, it's managed per thread
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
def get_session() -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
# Dependency to get the current tenant ID and set the context variable
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
token = request.cookies.get("tenant_details")
if not token:
# If no token is present, use the default schema or handle accordingly
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400, detail="Invalid token: tenant_id missing"
)
if not is_valid_schema_name(tenant_id):
raise ValueError("Invalid tenant ID format")
current_tenant_id.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token format")
except ValueError as e:
# Let the 400 error bubble up
raise HTTPException(status_code=400, detail=str(e))
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
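The dependency above expects the "tenant_details" cookie to hold an HS256 JWT whose payload carries a "tenant_id" claim. A hedged sketch of minting and verifying such a token with PyJWT follows; the helper and secret below are hypothetical and not part of the code above.

import jwt  # PyJWT, as imported at the top of this module

EXAMPLE_SECRET = "example-secret"  # placeholder; the real code uses SECRET_JWT_KEY

def mint_tenant_token(tenant_id: str) -> str:
    """Hypothetical helper producing the cookie value get_current_tenant_id() parses."""
    if not is_valid_schema_name(tenant_id):
        raise ValueError("Invalid tenant ID format")
    return jwt.encode({"tenant_id": tenant_id}, EXAMPLE_SECRET, algorithm="HS256")

token = mint_tenant_token("tenant_acme")
payload = jwt.decode(token, EXAMPLE_SECRET, algorithms=["HS256"])
assert payload["tenant_id"] == "tenant_acme"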
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
session = Session(engine, expire_on_commit=False)
@event.listens_for(session, "after_begin")
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
return session
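Because get_session_with_tenant hands back a plain Session (the search_path is applied by the after_begin listener each time a transaction starts), the caller owns the session lifecycle. A minimal sketch with a placeholder schema name:

session = get_session_with_tenant("tenant_acme")  # placeholder tenant/schema name
try:
    # search_path switches to "tenant_acme" when this transaction begins,
    # so unqualified table names resolve inside that schema.
    session.execute(text("SELECT 1"))
finally:
    session.close()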
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
async def get_async_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session)()
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
@@ -204,10 +356,3 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory


@@ -64,19 +64,12 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProvider:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)


@@ -50,7 +50,7 @@ from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
@@ -414,6 +414,12 @@ class ConnectorCredentialPair(Base):
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
# last successful prune
last_pruned: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
connector: Mapped["Connector"] = relationship(
@@ -1725,7 +1731,9 @@ class User__ExternalUserGroupId(Base):
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
class UsageReport(Base):


@@ -11,7 +11,7 @@ from danswer.db.index_attempt import (
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -54,7 +54,7 @@ def check_index_swap(db_session: Session) -> None:
)
if cc_pair_count > 0:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model


@@ -1,3 +1,4 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
@@ -107,12 +108,14 @@ def create_or_add_document_tag_list(
return all_tags
def get_tags_by_value_prefix_for_source_types(
def find_tags(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
sources: list[DocumentSource] | None,
limit: int | None,
db_session: Session,
# if set, both tag_key_prefix and tag_value_prefix must be a match
require_both_to_match: bool = False,
) -> list[Tag]:
query = select(Tag)
@@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types(
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
if tag_value_prefix:
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
query = query.where(or_(*conditions))
final_prefix_condition = (
and_(*conditions) if require_both_to_match else or_(*conditions)
)
query = query.where(final_prefix_condition)
if sources:
query = query.where(Tag.source.in_(sources))
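A hedged example of calling the renamed function with the new flag (prefixes and limit are placeholders):

# With require_both_to_match=True the key prefix AND the value prefix must match;
# with the default False, matching either prefix is enough (the original behavior).
tags = find_tags(
    tag_key_prefix="dept",
    tag_value_prefix="eng",
    sources=None,      # could also be a list of DocumentSource values
    limit=25,
    db_session=db_session,
    require_both_to_match=True,
)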


@@ -1,3 +1,6 @@
from sqlalchemy.orm import Session
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@@ -13,3 +16,14 @@ def get_default_document_index(
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
)
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
"""
TODO: Use redis to cache this or something
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)


@@ -55,6 +55,21 @@ class DocumentMetadata:
from_ingestion_api: bool = False
@dataclass
class VespaDocumentFields:
"""
Specifies fields in Vespa for a document. Fields set to None will be ignored.
Perhaps we should name this in an implementation agnostic fashion, but it's more
understandable like this for now.
"""
# all other fields except these 4 will always be left alone by the update request
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
@dataclass
class UpdateRequest:
"""
@@ -156,6 +171,16 @@ class Deletable(abc.ABC):
Class must implement the ability to delete document by their unique document ids.
"""
@abc.abstractmethod
def delete_single(self, doc_id: str) -> int:
"""
Given a single document id, hard delete it from the document index
Parameters:
- doc_id: document id as specified by the connector
"""
raise NotImplementedError
@abc.abstractmethod
def delete(self, doc_ids: list[str]) -> None:
"""
@@ -178,11 +203,9 @@ class Updatable(abc.ABC):
"""
@abc.abstractmethod
def update_single(self, update_request: UpdateRequest) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
"""
Updates some set of chunks for a document. The document and fields to update
are specified in the update request. Each update request in the list applies
its changes to a list of document ids.
Updates all chunks for a document with the specified fields.
None values mean that the field does not need an update.
The rationale for a single update function is that it allows retries and parallelism
@@ -190,14 +213,10 @@ class Updatable(abc.ABC):
us to individually handle error conditions per document.
Parameters:
- update_request: for a list of document ids in the update request, apply the same updates
to all of the documents with those ids.
- fields: the fields to update in the document. Any field set to None will not be changed.
Return:
- an HTTPStatus code. The code can be used to decide whether to fail immediately,
retry, etc. Although this method likely hits an HTTP API behind the
scenes, the usage of HTTPStatus is a convenience and the interface is not
actually HTTP specific.
None
"""
raise NotImplementedError
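Taken together with VespaDocumentFields above, callers of the new interface pass a document id plus only the fields they want changed. A hedged usage sketch (the document id, document set name, and document_index variable are placeholders):

# document_index: an Updatable implementation such as VespaIndex, obtained elsewhere.
# Hide a single document and update its document sets; unset fields are left alone.
fields = VespaDocumentFields(
    hidden=True,
    document_sets={"example_document_set"},
)
chunks_updated = document_index.update_single("example-doc-id", fields)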


@@ -1,5 +1,6 @@
import concurrent.futures
import io
import logging
import os
import re
import time
@@ -13,6 +14,7 @@ from typing import cast
import httpx
import requests
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -22,6 +24,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from danswer.document_index.vespa.chunk_retrieval import (
get_all_vespa_ids_for_document_id,
@@ -58,8 +61,8 @@ from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.batching import batch_generator
@@ -68,6 +71,10 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
@dataclass
class _VespaUpdateRequest:
@@ -140,7 +147,7 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
needs_reindexing = False
try:
@@ -377,90 +384,89 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, update_request: UpdateRequest) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
if len(update_request.document_ids) != 1:
raise ValueError("update_request must contain a single document id")
total_chunks_updated = 0
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
update_request.document_ids = [
replace_invalid_doc_id_characters(doc_id)
for doc_id in update_request.document_ids
]
normalized_doc_id = replace_invalid_doc_id_characters(doc_id)
# update_start = time.monotonic()
# Build the Vespa update payload from only the fields that were set
update_dict: dict[str, dict] = {"fields": {}}
if fields.boost is not None:
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if fields.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
if fields.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return 0
# Fetch all chunks for each document ahead of time
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
logger.debug(
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
}
if update_request.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{normalized_doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
# logger.debug(
# "Finished updating Vespa documents in %.2f seconds",
# time.monotonic() - update_start,
# )
while True:
try:
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
)
return
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to update chunks, details: {e.response.text}"
)
raise
resp_data = resp.json()
if "documentCount" in resp_data:
chunks_updated = resp_data["documentCount"]
total_chunks_updated += chunks_updated
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", resp_data["continuation"])
logger.debug(
f"VespaIndex.update_single: "
f"index={index_name} "
f"doc={normalized_doc_id} "
f"chunks_updated={total_chunks_updated}"
)
return total_chunks_updated
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
@@ -478,6 +484,70 @@ class VespaIndex(DocumentIndex):
delete_vespa_docs(
document_ids=doc_ids, index_name=index_name, http_client=http_client
)
return
def delete_single(self, doc_id: str) -> int:
"""Possibly faster overall than the delete method due to using a single
delete call with a selection query."""
total_chunks_deleted = 0
# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example
doc_id = replace_invalid_doc_id_characters(doc_id)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
while True:
try:
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
params=params,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to delete chunk, details: {e.response.text}"
)
raise
resp_data = resp.json()
if "documentCount" in resp_data:
chunks_deleted = resp_data["documentCount"]
total_chunks_deleted += chunks_deleted
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", resp_data["continuation"])
logger.debug(
f"VespaIndex.delete_single: "
f"index={index_name} "
f"doc={doc_id} "
f"chunks_deleted={total_chunks_deleted}"
)
return total_chunks_deleted
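One motivation stated earlier for the per-document interface is that it makes retries straightforward. Below is a sketch of a retry wrapper under that assumption; the helper, attempt count, and backoff are arbitrary illustrative choices, not part of the code above.

def update_single_with_retry(
    index: "VespaIndex",
    doc_id: str,
    fields: VespaDocumentFields,
    attempts: int = 3,
) -> int:
    """Hypothetical helper: retry transient HTTP failures from update_single."""
    for attempt in range(attempts):
        try:
            return index.update_single(doc_id, fields)
        except httpx.HTTPStatusError:
            if attempt == attempts - 1:
                raise
            time.sleep(2**attempt)  # simple exponential backoff
    return 0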
def id_based_retrieval(
self,


@@ -1,15 +0,0 @@
from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.store import FileSystemBackedDynamicConfigStore
from danswer.dynamic_configs.store import PostgresBackedDynamicConfigStore
def get_dynamic_config_store() -> DynamicConfigStore:
dynamic_config_store_type = DYNAMIC_CONFIG_STORE
if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__:
raise NotImplementedError("File based config store no longer supported")
if dynamic_config_store_type == PostgresBackedDynamicConfigStore.__name__:
return PostgresBackedDynamicConfigStore()
# TODO: change exception type
raise Exception("Unknown dynamic config store type")


@@ -1,102 +0,0 @@
import json
import os
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import cast
from filelock import FileLock
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.interface import JSON_ro
FILE_LOCK_TIMEOUT = 10
def _get_file_lock(file_name: Path) -> FileLock:
return FileLock(file_name.with_suffix(".lock"))
class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
def __init__(self, dir_path: str) -> None:
# TODO (chris): maybe require all possible keys to be passed in
# at app start somehow to prevent key overlaps
self.dir_path = Path(dir_path)
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
file_path = self.dir_path / key
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(file_path, "w+") as f:
json.dump(val, f)
def load(self, key: str) -> JSON_ro:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(self.dir_path / key) as f:
return cast(JSON_ro, json.load(f))
def delete(self, key: str) -> None:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
os.remove(file_path)
class PostgresBackedDynamicConfigStore(DynamicConfigStore):
@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:
session.close()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# The actual encryption/decryption is done in Postgres, we just need to choose
# which field to set
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()
def load(self, key: str) -> JSON_ro:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise ConfigNotFoundError
if obj.value is not None:
return cast(JSON_ro, obj.value)
if obj.encrypted_value is not None:
return cast(JSON_ro, obj.encrypted_value)
return None
def delete(self, key: str) -> None:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise ConfigNotFoundError
session.commit()


@@ -20,6 +20,8 @@ from pypdf.errors import PdfStreamError
from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -331,9 +333,10 @@ def file_io_to_text(file: IO[Any]) -> str:
def extract_file_text(
file_name: str | None,
file: IO[Any],
file_name: str,
break_on_unprocessable: bool = True,
extension: str | None = None,
) -> str:
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -345,22 +348,29 @@ def extract_file_text(
".html": parse_html_page_basic,
}
def _process_file() -> str:
if file_name:
extension = get_file_ext(file_name)
if check_file_ext_is_valid(extension):
return extension_to_function.get(extension, file_io_to_text)(file)
try:
if get_unstructured_api_key():
return unstructured_to_text(file, file_name)
# Either the file somehow has no name or the extension is not one that we are familiar with
if file_name or extension:
if extension is not None:
final_extension = extension
elif file_name is not None:
final_extension = get_file_ext(file_name)
if check_file_ext_is_valid(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we recognize
if is_text_file(file):
return file_io_to_text(file)
raise ValueError("Unknown file extension and unknown text encoding")
try:
return _process_file()
except Exception as e:
if break_on_unprocessable:
raise RuntimeError(f"Failed to process file: {str(e)}") from e
logger.warning(f"Failed to process file: {str(e)}")
raise RuntimeError(
f"Failed to process file {file_name or 'Unknown'}: {str(e)}"
) from e
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
return ""


@@ -0,0 +1,67 @@
from typing import Any
from typing import cast
from typing import IO
from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared
from danswer.configs.constants import KV_UNSTRUCTURED_API_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_unstructured_api_key() -> str | None:
kv_store = get_kv_store()
try:
return cast(str, kv_store.load(KV_UNSTRUCTURED_API_KEY))
except KvKeyNotFoundError:
return None
def update_unstructured_api_key(api_key: str) -> None:
kv_store = get_kv_store()
kv_store.store(KV_UNSTRUCTURED_API_KEY, api_key)
def delete_unstructured_api_key() -> None:
kv_store = get_kv_store()
kv_store.delete(KV_UNSTRUCTURED_API_KEY)
def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> operations.PartitionRequest:
try:
request = operations.PartitionRequest(
partition_parameters=shared.PartitionParameters(
files=shared.Files(content=file.read(), file_name=file_name),
**kwargs,
),
)
return request
except Exception as e:
logger.error(f"Error creating partition request for file {file_name}: {str(e)}")
raise
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="auto")
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
response = unstructured_client.general.partition(req) # type: ignore
if response.status_code != 200:
err = f"Received unexpected status code {response.status_code} from Unstructured API."
logger.error(err)
raise ValueError(err)
# Only parse the elements once the API has reported success
elements = dict_to_elements(response.elements)
return "\n\n".join(str(el) for el in elements)


@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any useable content, it will not be indexed
return chunks
def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks


@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""
def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)
@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)
@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)


@@ -0,0 +1,41 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")


@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -220,8 +221,8 @@ def index_doc_batch_prepare(
document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
document_ids=document_ids,
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(


@@ -0,0 +1,7 @@
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.store import PgRedisKVStore
def get_kv_store() -> KeyValueStore:
# this is the only one supported currently
return PgRedisKVStore()
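A minimal usage sketch of the key-value facade (the key and value below are placeholders):

from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError

kv_store = get_kv_store()
kv_store.store("example_key", {"enabled": True})  # write-through: Redis cache plus Postgres
print(kv_store.load("example_key"))               # served from Redis when the cache is warm

kv_store.delete("example_key")
try:
    kv_store.load("example_key")
except KvKeyNotFoundError:
    print("example_key is gone")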


@@ -9,11 +9,11 @@ JSON_ro: TypeAlias = (
)
class ConfigNotFoundError(Exception):
class KvKeyNotFoundError(Exception):
pass
class DynamicConfigStore:
class KeyValueStore:
@abc.abstractmethod
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
raise NotImplementedError


@@ -0,0 +1,96 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
REDIS_KEY_PREFIX = "danswer_kv_store:"
KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
@contextmanager
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
try:
self.redis_client.set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()
def load(self, key: str) -> JSON_ro:
try:
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
assert isinstance(redis_value, bytes)
return json.loads(redis_value.decode("utf-8"))
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
if obj.value is not None:
value = obj.value
elif obj.encrypted_value is not None:
value = obj.encrypted_value
else:
value = None
try:
self.redis_client.set(REDIS_KEY_PREFIX + key, json.dumps(value))
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self.redis_client.delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
session.commit()


@@ -1,3 +1,4 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
@@ -310,13 +311,15 @@ class Answer:
)
)
yield tool_runner.tool_final_result()
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
yield from self._process_llm_stream(
prompt=prompt,
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return
@@ -412,6 +415,10 @@ class Answer:
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
if self.skip_gen_ai_answer_generation:
raise ValueError(
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
@@ -476,10 +483,10 @@ class Answer:
final = tool_runner.tool_final_result()
yield final
if not self.skip_gen_ai_answer_generation:
prompt = prompt_builder.build()
prompt = prompt_builder.build()
yield from self._process_llm_stream(prompt=prompt, tools=None)
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -554,8 +561,7 @@ class Answer:
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
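The itertools.chain change above folds the already-consumed first message back in front of the rest of the stream, so the StreamStopInfo check applies uniformly to every item. A tiny standalone illustration of that pattern (toy values, unrelated to the real stream types):

import itertools

first = "already-pulled item"
rest = iter(["second", "third"])

# Equivalent to yielding `first` and then looping over `rest`, but every item,
# including the first, flows through the same loop body.
for item in itertools.chain([first], rest):
    print(item)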


@@ -18,6 +18,7 @@ from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
@@ -143,6 +144,12 @@ def build_citations_user_message(
prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
)
history_block = (
HISTORY_BLOCK.format(history_str=history_message) + "\n"
if history_message
else ""
)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
@@ -152,14 +159,14 @@ def build_citations_user_message(
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=question,
history_block=history_message,
history_block=history_block,
)
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
task_prompt=task_prompt_with_reminder,
user_query=question,
history_block=history_message,
history_block=history_block,
)
user_prompt = user_prompt.strip()


@@ -283,6 +283,7 @@ class DefaultMultiLLM(LLM):
_convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
for msg in prompt
]
elif isinstance(prompt, str):
prompt = [_convert_message_to_dict(HumanMessage(content=prompt))]
@@ -290,10 +291,12 @@ class DefaultMultiLLM(LLM):
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
api_key=self._api_key,
base_url=self._api_base,
api_version=self._api_version,
custom_llm_provider=self._custom_llm_provider,
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=prompt,
tools=tools,


@@ -1,4 +1,3 @@
import time
import traceback
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
@@ -23,55 +22,22 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -97,7 +63,6 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.search_settings import router as search_settings_router
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
@@ -109,15 +74,10 @@ from danswer.server.query_and_chat.query_backend import (
from danswer.server.query_and_chat.query_backend import basic_router as query_router
from danswer.server.settings.api import admin_router as settings_admin_router
from danswer.server.settings.api import basic_router as settings_router
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.gpu_utils import gpu_status_request
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import get_or_generate_uuid
from danswer.utils.telemetry import optional_telemetry
@@ -126,8 +86,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@@ -180,180 +138,14 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
provider="openai",
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
def update_default_multipass_indexing(db_session: Session) -> None:
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
if not docs_exist and not connectors_exist:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request()
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)
logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
updated_settings = SavedSearchSettings.from_db_model(current_settings)
# Enable multipass indexing if GPU is available or if using a cloud provider
updated_settings.multipass_indexing = (
gpu_available or current_settings.cloud_provider is not None
)
update_current_search_settings(db_session, updated_settings)
# Update settings with GPU availability
settings = load_settings()
settings.gpu_enabled = gpu_available
store_settings(settings)
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
else:
logger.debug(
"Existing docs or connectors found. Skipping multipass indexing update."
)
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
if isinstance(search_settings_dict, dict):
# Update current search settings
current_settings = get_current_search_settings(db_session)
# Update non-preserved fields
if current_settings:
current_settings_dict = SavedSearchSettings.from_db_model(
current_settings
).dict()
new_current_settings = SavedSearchSettings(
**{**current_settings_dict, **search_settings_dict}
)
update_current_search_settings(db_session, new_current_settings)
# Update secondary search settings
secondary_settings = get_secondary_search_settings(db_session)
if secondary_settings:
secondary_settings_dict = SavedSearchSettings.from_db_model(
secondary_settings
).dict()
new_secondary_settings = SavedSearchSettings(
**{**secondary_settings_dict, **search_settings_dict}
)
update_secondary_search_settings(
db_session,
new_secondary_settings,
)
# Delete the KV store entry after successful update
kv_store.delete(KV_SEARCH_SETTINGS)
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except ConfigNotFoundError:
logger.notice("No search config found in KV store.")
def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
try:
value = kv_store.load(KV_REINDEX_KEY)
logger.debug(f"Re-indexing flag has value {value}")
return
except ConfigNotFoundError:
# Only need to update the flag if it hasn't been set
pass
# If their first deployment is after the changes, it will
# enable this when the other changes go in, need to avoid
# this being set to False, then the user indexes things on the old version
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
if docs_exist or connectors_exist:
kv_store.store(KV_REINDEX_KEY, True)
else:
kv_store.store(KV_REINDEX_KEY, False)
def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
) -> bool:
# Vespa startup is a bit slow, so give it a few seconds
WAIT_SECONDS = 5
VESPA_ATTEMPTS = 5
for x in range(VESPA_ATTEMPTS):
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
)
logger.notice("Vespa setup complete.")
return True
except Exception:
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)
logger.error(
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
)
return False
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)
engine = get_sqlalchemy_engine()
SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME)
SqlEngine.init_engine(
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
)
engine = SqlEngine.get_engine()
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
@@ -368,95 +160,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
# await warm_up_connections()
await warm_up_connections()
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all indexing attempts for old embedding models; technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
translate_saved_search_settings(db_session)
# Does the user need to trigger a reindexing to bring the document index
# into a good state, marked in the kv store
mark_reindex_flag(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
success = setup_vespa(
document_index,
IndexingSetting.from_db_model(search_settings),
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
if not success:
raise RuntimeError(
"Could not connect to Vespa within the specified timeout."
)
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if search_settings.provider_type is None:
warm_up_bi_encoder(
embedding_model=EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
),
)
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
setup_danswer(db_session)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -619,7 +329,7 @@ if __name__ == "__main__":
f"Starting Danswer Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/"
)
if global_version.get_is_ee_version():
if global_version.is_ee_version():
logger.notice("Running Enterprise Edition")
uvicorn.run(app, host=APP_HOST, port=APP_PORT)
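
For readers less familiar with FastAPI's lifespan protocol used in this file, a stripped-down sketch of the pattern (placeholder names, not Danswer code): everything before yield runs once at startup, anything after it runs at shutdown.

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def minimal_lifespan(app: FastAPI) -> AsyncGenerator:
    # startup work: init the DB engine, verify settings, set up the document index, etc.
    print("starting up")
    yield
    # shutdown work: close pools/clients if needed
    print("shutting down")

app = FastAPI(lifespan=minimal_lifespan)
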

View File

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
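
The only thing this diff shows about Heartbeat is that it exposes a heartbeat() method called after each embedding batch. A hypothetical minimal stand-in, purely to illustrate the contract (the real danswer.indexing.indexing_heartbeat.Heartbeat may differ):

import time

class SimpleHeartbeat:
    # hypothetical stand-in: records the last time progress was observed so a
    # watchdog can detect stalled indexing work
    def __init__(self) -> None:
        self.last_beat = time.monotonic()

    def heartbeat(self) -> None:
        self.last_beat = time.monotonic()

    def seconds_since_last_beat(self) -> float:
        return time.monotonic() - self.last_beat
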
def encode(

View File

@@ -110,8 +110,8 @@ Respond "{SKIP_SEARCH}" if:
and additional information or details would provide little or no value.
- The query is some task that does not require additional information to handle.
{GENERAL_SEP_PAT}
Conversation History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
@@ -135,8 +135,8 @@ If there is a clear change in topic, disregard the previous messages.
Strip out any information that is not relevant for the retrieval task.
If the follow up message is an error or code snippet, repeat the same input back EXACTLY.
{GENERAL_SEP_PAT}
Chat History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
@@ -152,8 +152,8 @@ If a broad query might yield too many results, make it detailed.
If there is a clear change in topic, ensure the query reflects the new topic accurately.
Strip out any information that is not relevant for the internet search.
{GENERAL_SEP_PAT}
Chat History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
@@ -210,6 +210,7 @@ IMPORTANT: TRY NOT TO USE MORE THAN 5 WORDS, MAKE IT AS CONCISE AS POSSIBLE.
Focus the name on the important keywords to convey the topic of the conversation.
Chat History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}

View File

@@ -72,7 +72,8 @@ EMPTY_SAMPLE_JSON = {
JSON_PROMPT = f"""
{{system_prompt}}
{REQUIRE_JSON}
{{context_block}}{{history_block}}{{task_prompt}}
{{context_block}}{{history_block}}
{{task_prompt}}
SAMPLE RESPONSE:
```
@@ -91,6 +92,7 @@ SAMPLE RESPONSE:
# "conversation history" block
CITATIONS_PROMPT = f"""
Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
@@ -109,10 +111,7 @@ CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.
CHAT HISTORY:
{{history_block}}
{{task_prompt}}
{{history_block}}{{task_prompt}}
{QUESTION_PAT.upper()}
{{user_query}}
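
A small clarification on the template syntax in these prompt files (illustration only): the templates are f-strings, so {GENERAL_SEP_PAT} is substituted immediately, while doubled braces like {{chat_history}} escape to literal {chat_history} placeholders filled in a later formatting pass.

GENERAL_SEP_PAT = "-" * 20  # stand-in value; the real constant lives in danswer's prompt configs

TEMPLATE = f"""
Chat History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
"""

# GENERAL_SEP_PAT is already baked in; {chat_history} survives for the second pass
print(TEMPLATE.format(chat_history="user: hi\nassistant: hello"))
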

View File

@@ -3,23 +3,23 @@ from typing import Optional
import redis
from redis.client import Redis
from redis.connection import ConnectionPool
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
REDIS_POOL_MAX_CONNECTIONS = 10
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: ConnectionPool
_pool: redis.BlockingConnectionPool
def __new__(cls) -> "RedisPool":
if not cls._instance:
@@ -42,33 +42,52 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
) -> redis.BlockingConnectionPool:
"""We use BlockingConnectionPool because it will block and wait for a connection
rather than error if max_connections is reached. This is far more deterministic
behavior and aligned with how we want to use Redis."""
# ConnectionPool usage is not especially well documented.
# Useful examples: https://github.com/redis/redis-py/issues/780
if ssl:
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
connection_class=redis.SSLConnection,
ssl_ca_certs=ssl_ca_certs,
ssl_cert_reqs=ssl_cert_reqs,
)
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
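
To make the docstring's point concrete (sketch only, assuming a Redis server on localhost): a plain ConnectionPool raises once max_connections is exhausted, while BlockingConnectionPool makes the caller wait up to timeout seconds for a connection to be returned.

import redis

strict_pool = redis.ConnectionPool(max_connections=1)
held = strict_pool.get_connection("PING")
try:
    strict_pool.get_connection("PING")  # raises redis.ConnectionError: "Too many connections"
except redis.ConnectionError as exc:
    print(f"ConnectionPool gave up immediately: {exc}")
finally:
    strict_pool.release(held)

blocking_pool = redis.BlockingConnectionPool(max_connections=1, timeout=5)
client = redis.Redis(connection_pool=blocking_pool)
client.ping()  # a second borrower would block (up to 5s) instead of erroring
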

View File

@@ -1,8 +1,8 @@
from typing import cast
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger
@@ -17,10 +17,10 @@ def get_kv_search_settings() -> SavedSearchSettings | None:
if the value is updated by another process/instance of the API server. If this reads from an in-memory cache like
Redis then it will be ok. Until then this has some performance implications (though minor).
"""
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
except ConfigNotFoundError:
except KvKeyNotFoundError:
return None
except Exception as e:
logger.error(f"Error loading search settings: {e}")

View File

@@ -4,6 +4,7 @@ from fastapi import FastAPI
from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from danswer.auth.users import control_plane_dep
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
@@ -98,6 +99,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == control_plane_dep
):
found_auth = True
break
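
For context (sketch only, hypothetical helper name): the check above works because FastAPI exposes every route's resolved dependencies, so a startup audit can require each route to declare at least one known auth dependency.

from fastapi import FastAPI
from fastapi.routing import APIRoute

def assert_routes_have_auth(app: FastAPI, allowed_auth_deps: set) -> None:
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue  # skip docs/openapi/static routes
        dep_calls = {dep.call for dep in route.dependant.dependencies}
        if not dep_calls & allowed_auth_deps:
            raise RuntimeError(f"Route {route.path} has no recognized auth dependency")
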

View File

@@ -1,4 +1,5 @@
import math
from http import HTTPStatus
from fastapi import APIRouter
from fastapi import Depends
@@ -9,7 +10,11 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -26,17 +31,23 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -190,6 +201,156 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> bool:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
rcp = RedisConnectorPruning(cc_pair.id)
return rcp.is_pruning(db_session, get_redis_client())
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)
logger.info(
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Pruning task creation failed.",
)
return StatusResponse(
success=True,
message="Successfully created the pruning task.",
)
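
A hedged usage example for the new prune endpoint (not from the repo; assumes the API server on localhost:8080 and cc_pair_id=1, with authentication omitted since it is deployment-specific). The router prefix is "/manage", so the full path is /manage/admin/cc-pair/{cc_pair_id}/prune.

import requests

resp = requests.post("http://localhost:8080/manage/admin/cc-pair/1/prune")
if resp.status_code == 409:
    print("A pruning task is already in progress for this connector.")
else:
    resp.raise_for_status()
    print(resp.json())  # {"success": true, "message": "Successfully created the pruning task."}
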
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
# look up the last sync task for this connector (if it exists)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if not last_sync_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No sync task found.",
)
return CeleryTaskStatus(
id=last_sync_task.task_id,
name=last_sync_task.task_name,
status=last_sync_task.status,
start_time=last_sync_task.start_time,
register_time=last_sync_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.celery_app import (
sync_external_doc_permissions_task,
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if last_sync_task and check_task_is_live_and_not_timed_out(
last_sync_task, db_session
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair_id),
)
return StatusResponse(
success=True,
message="Successfully created the sync task.",
)
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,

View File

@@ -74,8 +74,8 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -116,7 +116,7 @@ def check_google_app_gmail_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_gmail_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
@@ -140,7 +140,7 @@ def delete_google_app_gmail_credentials(
) -> StatusResponse:
try:
delete_google_app_gmail_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -154,7 +154,7 @@ def check_google_app_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
@@ -178,7 +178,7 @@ def delete_google_app_credentials(
) -> StatusResponse:
try:
delete_google_app_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -192,7 +192,7 @@ def check_google_service_gmail_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_gmail_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -218,7 +218,7 @@ def delete_google_service_gmail_account_key(
) -> StatusResponse:
try:
delete_gmail_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -232,7 +232,7 @@ def check_google_service_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -258,7 +258,7 @@ def delete_google_service_account_key(
) -> StatusResponse:
try:
delete_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -280,7 +280,7 @@ def upsert_service_account_credential(
DocumentSource.GOOGLE_DRIVE,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
# first delete all existing service account credentials
@@ -306,7 +306,7 @@ def upsert_gmail_service_account_credential(
DocumentSource.GMAIL,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
# first delete all existing service account credentials
@@ -781,6 +781,7 @@ def connector_run_once(
detail="Connector has no valid credentials, cannot create index attempts.",
)
# Prevents index attempts for cc pairs that already have an index attempt currently running
skipped_credentials = [
credential_id
for credential_id in credential_ids
@@ -790,15 +791,15 @@ def connector_run_once(
credential_id=credential_id,
),
only_current=True,
disinclude_finished=True,
db_session=db_session,
disinclude_finished=True,
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]

View File

@@ -268,6 +268,14 @@ class CCPairFullInfo(BaseModel):
)
class CeleryTaskStatus(BaseModel):
id: str
name: str
status: TaskStatus
start_time: datetime | None
register_time: datetime | None
class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.background.celery.celery_app import celery_app
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
@@ -28,9 +29,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -113,7 +114,7 @@ def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# Only validate every so often
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
curr_time = datetime.now(tz=timezone.utc)
try:
last_check = datetime.fromtimestamp(
@@ -122,7 +123,7 @@ def validate_existing_genai_api_key(
check_freq_sec = timedelta(seconds=GENERATIVE_MODEL_ACCESS_CHECK_FREQ)
if curr_time - last_check < check_freq_sec:
return
except ConfigNotFoundError:
except KvKeyNotFoundError:
# First time checking the key, nothing unusual
pass
@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery_app import (
check_for_connector_deletion_task,
)
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
status=ConnectorCredentialPairStatus.DELETING,
)
# run the beat task to pick up this deletion early
check_for_connector_deletion_task.apply_async(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
celery_app.send_task(
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
)
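
The switch from apply_async to send_task above is worth a note: apply_async needs the task object imported (the local import removed earlier in this hunk existed to dodge a circular import), while send_task dispatches purely by the task's registered name. A minimal sketch, assuming a local Redis broker:

from celery import Celery

celery_app = Celery("example", broker="redis://localhost:6379/0")  # stand-in broker URL

# dispatch by registered task name only; no import of the task function needed
celery_app.send_task(
    "check_for_connector_deletion_task",
    priority=1,  # stand-in for DanswerCeleryPriority.HIGH
)
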

View File

@@ -10,6 +10,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import fetch_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
@@ -124,17 +125,26 @@ def list_llm_providers(
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
False,
description="True if creating a new provider, False if updating an existing one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")

View File

@@ -21,6 +21,9 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.document_index.factory import get_default_document_index
from danswer.file_processing.unstructured import delete_unstructured_api_key
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import update_unstructured_api_key
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.search.models import SearchSettingsCreationRequest
@@ -30,7 +33,6 @@ from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
router = APIRouter(prefix="/search-settings")
logger = setup_logger()
@@ -196,3 +198,27 @@ def update_saved_search_settings(
update_current_search_settings(
search_settings=search_settings, db_session=db_session
)
@router.get("/unstructured-api-key-set")
def unstructured_api_key_set(
_: User | None = Depends(current_admin_user),
) -> bool:
api_key = get_unstructured_api_key()
# don't print or log the key itself; just report whether one is configured
return api_key is not None
@router.put("/upsert-unstructured-api-key")
def upsert_unstructured_api_key(
unstructured_api_key: str,
_: User | None = Depends(current_admin_user),
) -> None:
update_unstructured_api_key(unstructured_api_key)
@router.delete("/delete-unstructured-api-key")
def delete_unstructured_api_key_endpoint(
_: User | None = Depends(current_admin_user),
) -> None:
delete_unstructured_api_key()
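
Hedged usage of the three new endpoints (not from the repo; host/port are assumptions and the required admin authentication is omitted since it is deployment-specific). The router prefix shown above is "/search-settings".

import requests

BASE = "http://localhost:8080/search-settings"  # assumes the router is mounted at the API root

requests.put(
    f"{BASE}/upsert-unstructured-api-key",
    params={"unstructured_api_key": "<key>"},  # bare str params become query parameters in FastAPI
).raise_for_status()

print(requests.get(f"{BASE}/unstructured-api-key-set").json())  # True once a key is stored

requests.delete(f"{BASE}/delete-unstructured-api-key").raise_for_status()
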

View File

@@ -18,7 +18,7 @@ from danswer.db.slack_bot_config import fetch_slack_bot_configs
from danswer.db.slack_bot_config import insert_slack_bot_config
from danswer.db.slack_bot_config import remove_slack_bot_config
from danswer.db.slack_bot_config import update_slack_bot_config
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.manage.models import SlackBotConfig
from danswer.server.manage.models import SlackBotConfigCreationRequest
from danswer.server.manage.models import SlackBotTokens
@@ -212,5 +212,5 @@ def put_tokens(
def get_tokens(_: User | None = Depends(current_admin_user)) -> SlackBotTokens:
try:
return fetch_tokens()
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="No tokens found")

View File

@@ -38,7 +38,7 @@ from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
@@ -367,7 +367,7 @@ def verify_user_logged_in(
# if auth type is disabled, return a dummy user with preferences from
# the key-value store
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
return fetch_no_auth_user(store)
raise HTTPException(
@@ -405,7 +405,7 @@ def update_user_default_model(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.default_model = request.default_model
set_no_auth_user_preferences(store, no_auth_user.preferences)
@@ -433,7 +433,7 @@ def update_user_assistant_list(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
@@ -487,7 +487,7 @@ def update_user_assistant_visibility(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
preferences = no_auth_user.preferences
updated_preferences = update_assistant_list(preferences, assistant_id, show)

Some files were not shown because too many files have changed in this diff.