Compare commits

..

195 Commits

Author SHA1 Message Date
Weves
5f82de7c45 Debug test 2024-09-23 11:05:27 -07:00
pablodanswer
45f67368a2 Add support for o1 (#2538)
* add o1 support + bump litellm/openai

* ports

* update exception message for testing
2024-09-22 23:16:28 +00:00
pablodanswer
014ba9e220 Begin distinguishing upsert operations for clarity (#2535)
* additional clarity for llm provider creation / updates

* update provider APIs

* update typing (minor)
2024-09-21 22:36:22 +00:00
pablodanswer
ba64543dd7 Updated modals for clarity (#2529)
* udpated modals for clarity

* fix build
2024-09-21 19:55:54 +00:00
pablodanswer
18c62a0c24 Add additional custom tooling configuration (#2426)
* add custom headers

* add tool seeding

* squash

* tmep

* validated

* rm

* update typing

* update alembic

* update import name

* reformat

* alembic
2024-09-20 23:12:52 +00:00
Chris Weaver
33f555922c Fix duplicate users from slack / web (#2530) 2024-09-20 21:51:33 +00:00
pablodanswer
05f6f6d5b5 update default search assistant selection (#2527)
* update default search assistant selection

* update language
2024-09-20 21:21:44 +00:00
hagen-danswer
19dae1d870 Wrote tests for the chat apis (#2525)
* Wrote tests for the chat apis

* slight changes to the case
2024-09-20 19:00:03 +00:00
rkuo-danswer
6d859bd37c try adding build essential (#2526) 2024-09-20 11:51:44 -07:00
pablodanswer
122e3fa3fa Access type (#2523) 2024-09-20 11:16:37 -07:00
pablodanswer
87b542b335 align alembic 2024-09-20 11:13:00 -07:00
pablodanswer
00229d2abe Add start date to persona (#2407)
* add start date to persona

* remove logs

* rename

* update assistant editor

* update alembic

* update alembic

* update alembic

* udpate alembic

* remove rebase artifacts
2024-09-20 16:39:34 +00:00
pablodanswer
5f2644985c Route name (#2520)
* clearer refresh logic

* rename path
2024-09-20 15:44:28 +00:00
pablodanswer
c82a36ad68 Saml account fastapi deletion (#2512)
* saml account fastapi deletion

* update error detail
2024-09-20 00:20:50 +00:00
hagen-danswer
16d1c19d9f Added bool to disable chat_session_id check for search_docs for api 2024-09-19 17:36:46 -07:00
pablodanswer
9f179940f8 Asana connector (community originated) (#2485)
* initial Asana connector

* hint on how to get Asana workspace ID

* re-format with black

* re-order imports

* update asana connector for clarity

* minor robustification

* minor update to naming

* update for best practice

* update connector

---------

Co-authored-by: Daniel Naber <naber@danielnaber.de>
2024-09-19 23:54:18 +00:00
pablodanswer
8a8e2b310e Assistants panel rework (#2509)
* update user model

* squash - update assistant gallery

* rework assistant display logic + ux

* update tool + assistant display

* update a couple function names

* update typing + some logic

* remove unnecessary comments

* finalize functionality

* updated logic

* fully functional

* remove logs + ports

* small update to logic

* update typing

* allow seeding of display priority

* reorder migrations

* update for alembic
2024-09-19 23:36:15 +00:00
hagen-danswer
2274cab554 Added permission syncing (#2340)
* Added permission syncing on the backend

* Rewored to work with celery

alembic fix

fixed test

* frontend changes

* got groups working

* added comments and fixed public docs

* fixed merge issues

* frontend complete!

* frontend cleanup and mypy fixes

* refactored connector access_type selection

* mypy fixes

* minor refactor and frontend improvements

* get to fetch

* renames and comments

* minor change to var names

* got curator stuff working

* addressed pablo's comments

* refactored user_external_group to reference users table

* implemented polling

* small refactor

* fixed a whoopsies on the frontend

* added scripts to seed dummy docs and test query times

* fixed frontend build issue

* alembic fix

* handled is_public overlap

* yuhong feedback

* added more checks for sync

* black

* mypy

* fixed circular import

* todos

* alembic fix

* alembic
2024-09-19 22:07:36 +00:00
pablodanswer
ef104e9a82 Non-spotfix deletion of users (#2499)
* add description / robustify

* additional minor robustification (ideally we organized cascades slightly better)

* update deletion for simplicity

* minor typing update
2024-09-19 20:02:36 +00:00
hagen-danswer
a575d7f1eb Citations prompt for slack now includes thread history (#2510) 2024-09-19 19:31:26 +00:00
pablodanswer
f404c4b448 Move code block default language creation to citation processing (#2501)
* move code block default language creation to citaiton processing

* add test cases

* update copy
2024-09-19 06:00:58 +00:00
rkuo-danswer
3884f1d70a Bugfix/larger test runner (#2508)
* add pip retries to the github workflows too

* let's try running on amd64 ... docker builds are unusually flaky

* bump

* try large

* no yaml anchors

* switch back down to Amd64

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-19 05:36:07 +00:00
rkuo-danswer
bc9d5fece7 prevent trying to submit to jobclient when it can't take any more work (reduces log spam) (#2482) 2024-09-19 04:01:15 +00:00
rkuo-danswer
bb279a8580 add pip retries. should help with github's occasional flaky network during build/test (#2506) 2024-09-19 00:46:41 +00:00
pablodanswer
a9403016c9 fix basic auth (#2505) 2024-09-18 22:45:58 +00:00
hagen-danswer
f3cea79c1c Deleting a connector should redirect to the indexing status page (#2504)
* Deleting a connector should redirect to the indexing status page

* minor update to dev background jobs

* update refresh logic

* remove print statement

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-09-18 21:38:35 +00:00
hagen-danswer
54bb79303c corrected error message (#2502) 2024-09-18 19:13:28 +00:00
pablodanswer
d3dfabb20e fix parentheses (#2486) 2024-09-18 18:39:23 +00:00
pablodanswer
7d1ec1095c proper z index for chat bubbles (#2500) 2024-09-18 18:02:50 +00:00
rkuo-danswer
f531d071af Feature/background deletion (#2337)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* add db refresh to connector deletion

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-18 16:50:11 +00:00
Chris Weaver
4218814385 Add flow to query history CSV (#2492) 2024-09-18 14:23:56 +00:00
rkuo-danswer
e662e3b57d clarify ssl cert reqs (#2494)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-18 05:35:57 +00:00
pablodanswer
2073820e33 Update default assistants to all visible (#2490)
* update default assistants to all visible

* update with catch-all

* minor update

* update
2024-09-18 02:08:11 +00:00
Chris Weaver
5f25b243c5 Add back llm_chunks_indices (#2491) 2024-09-18 01:21:31 +00:00
pablodanswer
a9427f190a Extend time range (contributor submission) (#2484)
* added new options for time range; removed duplicated code

* refactor + remove unused code

---------

Co-authored-by: Zoltan Szabo <zoltan.szabo@eaudeweb.ro>
2024-09-17 22:36:25 +00:00
pablodanswer
18fbe9d7e8 Warn users of gpu-sensitive operation (#2488)
* warn users of gpu-sensitive operation

* update copy
2024-09-17 21:59:43 +00:00
Chris Weaver
75c9b1cafe Fix concatenate string with toolcallkickoff issue (#2487) 2024-09-17 21:25:06 +00:00
rkuo-danswer
632a8f700b Feature/celery backend db number (#2475)
* use separate database number for celery result backend

* add comments

* add env var for celery's result_expires

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-17 21:06:36 +00:00
pablodanswer
cd58c96014 Memoize AI message component (#2483)
* memoize AI message component

* rename memoized file

* remove "zz"

* update name

* memoize for coverage

* add display name
2024-09-17 18:47:23 +00:00
pablodanswer
c5032d25c9 Minor clarity update for connectors (#2480) 2024-09-17 10:25:39 -07:00
pablodanswer
72acde6fd4 Handle tool errors in display properly (can show valueError to user) (#2481)
* handle tool errors in display properly (can show valueerrors to user)

* update for clarity
2024-09-17 17:08:46 +00:00
rkuo-danswer
5596a68d08 harden migration (#2476)
* harden migration

* remove duplicate line
2024-09-17 16:44:53 +00:00
Weves
5b18409c89 Change user-message to user-prompt 2024-09-16 21:53:27 -07:00
Chris Weaver
84272af5ac Add back scrolling to ExceptionTraceModal (#2473) 2024-09-17 02:25:53 +00:00
pablodanswer
6bef70c8b7 ensure disabled gets propagated 2024-09-16 19:27:31 -07:00
pablodanswer
7f7559e3d2 Allow users to share assistants (#2434)
* enable assistant sharing

* functional

* remove logs

* revert ports

* remove accidental update

* minor updates to copy

* update formatting

* update for merge queue
2024-09-17 01:35:29 +00:00
Chris Weaver
7ba829a585 Add top_documents to APIs (#2469)
* Add top_documents

* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-09-16 23:48:33 +00:00
trial-danswer
8b2ecb4eab EE movement followup for Standard Answers (#2467)
* Move StandardAnswer to EE section of danswer/db/models

* Move StandardAnswer DB layer to EE

* Add EERequiredError for distinct error handling here

* Handle EE fallback for slack bot config

* Migrate all standard answer models to ee

* Flagging categories for removal

* Add missing versioned impl for update_slack_bot_config

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 22:05:53 +00:00
pablodanswer
2dd3870504 Add ability to specify persona in API request (#2302)
* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
2024-09-16 21:31:01 +00:00
pablodanswer
df464fc54b Allow for CORS Origin Setting (#2449)
* allow setting of CORS origin

* simplify

* add environment variable + rename

* slightly more efficient

* simplify so mypy doens't complain

* temp

* go back to my preferred formatting
2024-09-16 18:54:36 +00:00
pablodanswer
96b98fbc4a Make it impossible to switch to non-image (#2440)
* make it impossible to switch to non-image

* revert ports

* proper provider support

* remove unused imports

* minor rename

* simplify interface

* remove logs
2024-09-16 18:35:40 +00:00
trial-danswer
66cf67d04d hotfix: sqlalchemy default -> server_default (#2442)
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 17:49:01 +00:00
pablodanswer
285bdbbaf9 Fix stop generating locally (#2452)
* fix stop generating locally

* .
2024-09-15 23:55:30 +00:00
pablodanswer
e2c37d6847 Test stream + Update Copy (#2317)
* update copy + conditional ordering

* answer stream checks

* update

* add basic tests for chat streams

* slightly simplify

* fix typing

* quick typing updates + nits
2024-09-15 19:40:48 +00:00
Yuhong Sun
3ff2ba7ee4 k (#2450) 2024-09-15 17:32:58 +00:00
pablodanswer
290f4f0f8c add some minor ux updates (#2441) 2024-09-15 08:29:31 +00:00
rkuo-danswer
3c934a93cd using is_up_to_date cached outside of the fence was causing a race condition where the same sync could be kicked off again (#2433) 2024-09-15 06:27:05 +00:00
Yuhong Sun
a51b0f636e Logs from API Server Container on Merge Queue (#2448)
* k

* k
2024-09-14 20:32:18 +00:00
pablodanswer
a50c2e30ec Very minor polish (#2445)
* fix minor polish

* cleaner chat flow

* remove keys

* slight robustification to copying
2024-09-14 17:54:29 +00:00
pablodanswer
ee278522ef update indexing status clarity (#2446) 2024-09-14 17:19:55 +00:00
trial-danswer
430c9a47d7 Match any/all keywords in Standard Answers (#2443)
* migration: add column "match_any_keywords" to StandardAnswer

* Implement any/all keyword matching for standard answers

* Add match_any_keywords to non-searchable fields

* Remove stray print

* Simplify Slack messages for any and all cases

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-14 05:28:07 +00:00
hj-danswer
974f85da66 Migrate standard answers implementations to ee/ (#2378)
* Migrate standard answers implementations to ee/

* renaming

* Clean up slackbot non-ee standard answers import

* Move backend api/manage/standard_answer route to ee

* Move standard answers web UI to ee

* Hide standard answer controls in bot edit page

* Kwargs for fetch_versioned_implementation

* Add docstring explaining return types for handle_standard_answers

* Consolidate blocks into ee/handle_standard_answers

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-14 01:57:03 +00:00
hagen-danswer
a63cb9da43 fixed /danswer handling (#2436)
* fixed

* mypy

* cleaned up and commented

* mypy

* Update handle_regular_answer.py
2024-09-14 01:21:13 +00:00
rkuo-danswer
d807ad7699 fix document set connection removal sync, add tests for document set and user group removal (#2437) 2024-09-14 01:01:26 +00:00
hj-danswer
3cb00de6d4 Support regex in standard answers (#2377)
* Support regex in standard answers

* fix mypy

* Add match_regex boolean column to StandardAnswer

* Add match_regex flag and validation to Pydantic models

* GET /manage/admin/standard-answer: add match_regex to create_standard_answer

* PATCH /manage/admin/standard-answer/🆔 add match_regex to update_standard_answer

* Add "Match Regex" toggle to standard answer form

* Decode error pattern in case it's bytes

* Refactor regex support to use match_regex flag instead of supplemental tuple

* Better error handling for invalid regexes

* Show "match regex" in table and style keywords appropriately

* Fix stale UI copy for non-"match_regex" branch

* Fix stale docstring in find_matching_standard_answers

* Update down_revision to reflect most recent migration

* Update UI copy

* Initial implementation of match group display

* Fix pydantic StandardAnswer vs SQLAlchemy StandardAnswer model usage

* Update docstring return type

* Fix missing key prop

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-14 00:07:42 +00:00
Chris Weaver
da6e46ae75 Slack flow improvements (#2366) 2024-09-13 16:56:45 -07:00
pablodanswer
648c2531f9 Add custom tool chat session / message ID dynamic prompting (#2404)
* add custom tool chat session / message ID dynamic prompting

* update some formatting

* code organization + remove unnecessary card

* remove log

* update for clarity
2024-09-13 18:42:21 +00:00
pablodanswer
fc98c560a4 Add fix for logging (#2431) 2024-09-13 11:27:20 -07:00
pablodanswer
566f44fcd6 Minor update to llm image ability tracking (#2423)
* minor update to llm image ability tracking

* quick robustification
2024-09-13 17:24:51 +00:00
rkuo-danswer
2fe49e5efb add ssl testing for redis against a cloud instance (#2422) 2024-09-13 10:28:04 -07:00
rkuo-danswer
f58acd4e2a Add redis to helm chart (#2390) 2024-09-13 10:26:51 -07:00
pablodanswer
53008a0271 update multipass indeixng server default 2024-09-13 10:24:26 -07:00
pablodanswer
13278663d9 Update refresh + robustify embeddings (#2420)
* update refresh + robustify embeddings

* squash
2024-09-13 14:26:33 +00:00
pablodanswer
31ca6857fb Custom Refresh on Client Side (#2376) 2024-09-13 00:04:03 -07:00
pablodanswer
6dd91414be delete chat session immediately 2024-09-13 00:02:43 -07:00
rkuo-danswer
140c34e59e ephemeral behavior for redis (#2373)
* ephemeral behavior for redis

* notes for redis command line consistency
2024-09-13 04:48:50 +00:00
rkuo-danswer
da8e68b320 reformat celery logging to match danswer style logging across services (#2409)
* reformat celery logging to match danswer style logging across services

* mypy fixes

* handle logfile argument
2024-09-13 01:51:51 +00:00
hagen-danswer
e9a616e579 Added search_doc_ids to the simple api to allow for skipping search (#2421)
* Added search_doc_ids to the simple api to allow for skipping search

* comment

* fixed behaviour
2024-09-12 23:22:41 +00:00
pablodanswer
cb2169f2a3 Warm up reranker on model switch (#2408)
* warm up reranker on model switch

* properly type

* fix issue

* Update search_settings.py
2024-09-12 22:12:17 +00:00
pablodanswer
79aa5dd6e0 add a tiny bit of clarity to index doc counts (#2414) 2024-09-12 21:59:10 +00:00
hagen-danswer
604ebafe6c simple apis now cited/context doc indices (#2419)
* simple apis now cited/context doc indices

* minor fixes
2024-09-12 21:29:24 +00:00
pablodanswer
a2d775efbd Reformatted tailwind config (#2417)
* reformatted tailwind config

* minor update
2024-09-12 19:41:11 +00:00
rkuo-danswer
641690e3f7 fix enabling ssl in connection pool (#2418) 2024-09-12 19:18:04 +00:00
rkuo-danswer
eebf98e3a6 fix setting redis_scheme (#2416) 2024-09-12 18:07:38 +00:00
rkuo-danswer
4bc4da29f5 add SSL parameter support for redis (#2389)
* add SSL parameter support for redis

* add ssl support to redis pool
2024-09-12 16:18:11 +00:00
pablodanswer
7af572d0e7 display only failed (#2413) 2024-09-12 16:01:17 +00:00
pablodanswer
58bdf9d684 Add connector deletion failure message (#2392) 2024-09-11 22:38:15 -07:00
pablodanswer
f69922fff7 Add environment variable for setting vespa search threads (#2400) 2024-09-11 22:37:38 -07:00
pablodanswer
d4d37c9cdd add bedrock models (#2405) 2024-09-12 04:34:43 +00:00
Yuhong Sun
2654df49fd Update CONTRIBUTING.md 2024-09-11 19:17:23 -07:00
pablodanswer
aee5fcd4e0 Add env variables for overriding embedding batch size (#2395)
* add env variabels for overriding

* proper ports

* proper overrides
2024-09-12 00:51:45 +00:00
pablodanswer
2c77dd241b Add error table to re-indexing (#2388)
* add error table to re-indexing

* robustify

* update with proper comment

* add popup

* update typo
2024-09-11 22:55:55 +00:00
pablodanswer
d90c90dd92 simplify unnecessary display logic (#2406) 2024-09-11 21:35:50 +00:00
pablodanswer
2c971cf774 add claude image-support 2024-09-11 13:31:27 -07:00
trial-danswer
eab55bdd85 Misc clarifications for CONTRIBUTING.md (#2401)
* Reorder and clarify dependency installation instructions

* Clarify instructions for local development with Docker external deps vs full Docker stack

* Final words at the end of the local setup process

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-11 19:16:37 +00:00
rkuo-danswer
f4f2fb5943 Bugfix/connector deletion test (#2402)
* fixes a bug with deleting connectors and foreign keys

* test foreign key handling on deletion
2024-09-11 12:04:27 -07:00
rkuo-danswer
71f2f1a90a fixes a bug with deleting connectors and foreign keys (#2398) 2024-09-11 12:03:51 -07:00
hagen-danswer
74a2271422 Added HARD_DELETE_CHATS to environment variables (#2397) 2024-09-11 18:08:29 +00:00
trial-danswer
d42fb6ce34 Add link to macOS contributions doc for installing Python 3.11 (#2396)
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-11 17:45:52 +00:00
pablodanswer
0d749ebd46 add ccpair id to logging (#2391) 2024-09-11 01:27:03 +00:00
pablodanswer
9f6e8bd124 Improve Dev Experience (#2347)
* clean interfaces + improve dex experience

* update formatting

* update ports

* ports

* remove some number of unnecessary lines

* remove unnecssary isPublicGroupSelector checks in all spots

* add comment

* update building
2024-09-10 20:49:04 +00:00
pablodanswer
3a2a6abed4 Add basic virtualization (#2370)
* add basic virtualization

* functioning perfectly

* squash

* change ports

* remove some comments

* remove comment

* update buffering clarity
2024-09-10 19:06:04 +00:00
pablodanswer
07f49a384f Update spread order (#2386)
* update spread

* update
2024-09-10 18:04:47 +00:00
rkuo-danswer
f1c5e80f17 Feature/background processing (#2275)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-10 16:28:19 +00:00
pablodanswer
b7ad810d83 Prevent spam search (#2367) 2024-09-10 08:44:50 -07:00
pablodanswer
99b28643f7 show groups if they exist for user (#2384) 2024-09-10 15:14:30 +00:00
rkuo-danswer
f52d1142eb Fail instead of continuing if vespa cannot be reached within the time… (#2379)
* Fail instead of continuing if vespa cannot be reached within the timeout period

* improve startup readability

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-10 03:10:25 +00:00
pablodanswer
e563746730 Consent screen (#2381)
* update

* add consent popup

* rm
2024-09-10 02:40:32 +00:00
Yuhong Sun
aa86830bde mypy 2024-09-09 16:43:45 -07:00
James Jordan
4558351801 Zendesk tickets (#2192) 2024-09-09 16:36:53 -07:00
Sebastian Müller
a4dcae57cd Google Drive Plaintext Types (#2371) 2024-09-09 15:37:47 -07:00
pablodanswer
dbd56f946f address pablo's nits (#2368) 2024-09-09 14:44:27 -07:00
hj-danswer
e4e4765c60 Add user when they interact outside of UI (e.g. Slack bot) (#2369)
* Add user when they interact outside of UI (e.g. Slack bot)

* fix mypy errors

* don't use user manager to avoid async messiness

* fix email is none scenario

* fix mypy

* make code slightly clearer

* PR comments

* get slack email in generate button as well

* fix alembic migration

* update name to be more descriptive

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
2024-09-09 20:21:31 +00:00
rkuo-danswer
c967f53c02 docker versions have been deprecated for a while, so fixing the annoying warning (#2372) 2024-09-09 18:26:12 +00:00
pablodanswer
3a9b964d5c Add Litellm Rerank proxy (#2346)
* add ability ot set reranking litellm proxy

* add fully functional rerank litellm cards

* minor formatting enforcement

* remove logs
2024-09-09 15:57:01 +00:00
Yuhong Sun
f04ecbf87a Un-bump nltk due to llamaindex issue 2024-09-08 16:39:19 -07:00
Shukant Pal
362156f97e Model inference for connector classifier on queries (#2137) 2024-09-08 14:46:00 -07:00
Andres Jose Sebastian Rincon Gonzalez
3fa9676478 [1802] adjust the code to support a different db schemas (#1803) 2024-09-08 14:16:54 -07:00
Chris Weaver
be4b6189d2 Fix streaming auth locally (#2357) 2024-09-08 14:01:26 -07:00
pablodanswer
ace041415a Clearer onboarding + Provider Updates (#2361) 2024-09-08 13:35:20 -07:00
Yuhong Sun
148c2a7375 Remove wordnet (#2365) 2024-09-08 12:34:09 -07:00
pablodanswer
1555ac9dab More explicit credential creation flow (#2363)
* more explcit drive credential creation flow

* remove logs

* update naming

* fix user-contributed formatting

* fix (^) v2
2024-09-08 12:09:23 -07:00
Weves
80de408cef Fix formatting 2024-09-08 12:09:14 -07:00
Cola Chen
e20c825e16 Notion Connector to skip reading external blocks in NotionConnector
The commit skips reading 'external_object_instance_page' blocks in the NotionConnector due to the lack of support in the Notion API. This change is in response to the issue #1761.

Co-authored-by: Cola Chen <6825116+colachg@users.noreply.github.com>
2024-09-08 11:34:04 -07:00
mattboret
b0568ac8ae Sharepoint: Fix get all sites (#1700)
Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2024-09-08 11:28:11 -07:00
Art Matsak
0896d3b7da Fix content extraction from JIRA with API v2 vs. v3 (#1678) 2024-09-08 11:27:14 -07:00
Kshitiz Gupta
87b27046bd changes to the docker file for mac (#1773) 2024-09-08 11:02:18 -07:00
dependabot[bot]
5e9c6d1499 Bump aiohttp from 3.9.4 to 3.10.2 in /backend/requirements (#2097)
Bumps [aiohttp](https://github.com/aio-libs/aiohttp) from 3.9.4 to 3.10.2.
- [Release notes](https://github.com/aio-libs/aiohttp/releases)
- [Changelog](https://github.com/aio-libs/aiohttp/blob/master/CHANGES.rst)
- [Commits](https://github.com/aio-libs/aiohttp/compare/v3.9.4...v3.10.2)

---
updated-dependencies:
- dependency-name: aiohttp
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-08 10:59:47 -07:00
dependabot[bot]
50211ec401 Bump nltk from 3.8.1 to 3.9 in /backend/requirements (#2174)
Bumps [nltk](https://github.com/nltk/nltk) from 3.8.1 to 3.9.
- [Changelog](https://github.com/nltk/nltk/blob/develop/ChangeLog)
- [Commits](https://github.com/nltk/nltk/compare/3.8.1...3.9)

---
updated-dependencies:
- dependency-name: nltk
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-08 10:50:36 -07:00
Bart Schuller
6012a7cbd9 Fix multilingual .env embedding dimension (#1976) 2024-09-08 10:25:07 -07:00
dependabot[bot]
1e4b27185d Bump torch from 2.0.1 to 2.2.0 in /backend/requirements (#1933)
Bumps [torch](https://github.com/pytorch/pytorch) from 2.0.1 to 2.2.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v2.0.1...v2.2.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-08 10:17:17 -07:00
Moshe Zada
0c66da17bb Web Connector - Get doc_updated_at from Last-Modified header (#1693) 2024-09-08 10:05:04 -07:00
Art Matsak
d985cd4352 Fix JIRA comment indexing when author has no email (#1663) 2024-09-08 09:43:09 -07:00
Yuhong Sun
c8891a5829 Remove LangChain Community (#2362) 2024-09-08 09:41:20 -07:00
Art Matsak
51a13f5fc7 Implement indexing of simple tables in Word files (#1651) 2024-09-08 09:38:46 -07:00
dependabot[bot]
57c1deb8b8 Bump braces from 3.0.2 to 3.0.3 in /web (#1628)
Bumps [braces](https://github.com/micromatch/braces) from 3.0.2 to 3.0.3.
- [Changelog](https://github.com/micromatch/braces/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/braces/compare/3.0.2...3.0.3)

---
updated-dependencies:
- dependency-name: braces
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-07 21:06:34 -07:00
dependabot[bot]
e2e04af7e2 Bump msal from 1.26.0 to 1.28.0 in /backend/requirements (#1626)
Bumps [msal](https://github.com/AzureAD/microsoft-authentication-library-for-python) from 1.26.0 to 1.28.0.
- [Release notes](https://github.com/AzureAD/microsoft-authentication-library-for-python/releases)
- [Commits](https://github.com/AzureAD/microsoft-authentication-library-for-python/compare/1.26.0...1.28.0)

---
updated-dependencies:
- dependency-name: msal
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-07 21:05:11 -07:00
lombax85
c1735fcd3a Google Drive connector - txt and markdown support (#1469) 2024-09-07 20:28:23 -07:00
hj-danswer
b43e5735d7 Use user information in Slack bot DMs (#2360)
* Use user information from Slack bot DMs

* fix lint

---------

Co-authored-by: Hyeong Joon Suh <hyeongjoonsuh@Hyeongs-MacBook-Pro.local>
2024-09-08 03:08:24 +00:00
pablodanswer
7d4f8ef4e8 Minor Confluence Fixes for Robustification (#2349)
* add connector config

* update confluence connector
2024-09-08 01:39:49 +00:00
Weves
7c03b6f521 Fix responses for HTTPExceptions 2024-09-07 17:40:21 -07:00
Chris Weaver
ccf986808c Add retries (#2358)
* Add retries

* fix

* add

* remove --build

* Remove cache-to

* Don't push

* Add back push

* Add newline

* Remove alembic logs
2024-09-08 00:12:32 +00:00
pablodanswer
350482e53e Squash misc UX bugs (#2356) 2024-09-07 14:26:14 -07:00
pablodanswer
fb3d7330fa minor QOL improvement on first chat (#2353) 2024-09-07 14:25:05 -07:00
Yuhong Sun
6cec31088d CONTRIBUTING updates (#2354) 2024-09-07 14:05:36 -07:00
pablodanswer
491f3254a5 regeneration - don't remove human message unnecessarily 2024-09-06 15:38:02 -07:00
pablodanswer
5abf67fbf0 PDF metadata + list defaults (#2341)
* validate web list

* update pdf extraction of metadat

* remove pdf + log

* stricter type enforcing

* fix up indexing widths

* minor formatting

* add list case

* check for empty metadata
2024-09-06 21:21:24 +00:00
rkuo-danswer
2933c3598b first cut at redis (#2226)
* first cut at redis

* fix startup dependencies on redis

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* update contributing guide

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-06 19:21:29 +00:00
pablodanswer
aeb6060854 Add ability to delete users (#2342)
* add ability to delete users

* fix tiny build issue

* Add comments
2024-09-06 17:37:04 +00:00
hagen-danswer
8977b1b5fc Paginate connector page (#2328)
* Added pagination to individual connector pages

* I cooked

* Gordon Ramsay in this b

* meepe

* properly calculated max chunk and switch dict to array

* chunks -> batches

* increased max page size

* renmaed var
2024-09-06 17:00:25 +00:00
pablodanswer
69c0419146 Updated refreshing (#2327)
* clean up + add environment variables

* remove log

* update

* update api settings

* somewhat cleaner refresh functionality

* fully functional

* update settings

* validated

* remove random logs

* remove unneeded paramter + log

* move to ee + remove comments

* Cleanup unused

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-09-06 04:36:55 +00:00
pablodanswer
2bd3833c55 Update search settings + chat/search handling (#2333)
* validate web list

* update search settings + chat/search handling

* remove accidentally added search manager

* minor build fix

* push from local
2024-09-06 00:07:39 +00:00
rkuo-danswer
2d7b312e6c harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now. (#2338) 2024-09-05 20:09:33 +00:00
pablodanswer
ebe3674ca7 update for edge case (#2336) 2024-09-05 17:58:49 +00:00
pablodanswer
04f83eb1e1 Proper popover behavior, no showing queries with no docs, + bubbles (#2330) 2024-09-04 21:26:19 -07:00
pablodanswer
420aabc963 Update UX (#2324) 2024-09-04 18:45:52 -07:00
pablodanswer
61a17319c9 rename directory if needed 2024-09-04 17:22:59 -07:00
hagen-danswer
e4c85352b4 made connectors summary page faster (#2320)
* made connectors summary page faster

* not worth risk
2024-09-04 23:25:45 +00:00
pablodanswer
34ba3181ff Update auth for litellm proxy (#2316)
* update for auth

* validated embedding model names

* remove embedding provider

* remove logs

* add ability to delete search setting

* add abiility to delete models + more streamlined API endpoints

* remove upsert

* minor typing fix

* add connector utils
2024-09-04 20:59:07 +00:00
rkuo-danswer
630e2248bd fixing a race condition in celery task wrapper. could randomly blow up any task. (#2321) 2024-09-04 04:17:29 +00:00
hagen-danswer
c358c91e4c Added instance domain to telemetry (#2310) 2024-09-03 21:04:40 -07:00
Yuhong Sun
2b7915f33b Update Connector README PATH (#2323) 2024-09-03 20:56:37 -07:00
pablodanswer
0ff1a023cd Minor search setting clarity (#2300)
* minor search setting clarity

* 5433

* squash

* remove logs
2024-09-03 20:48:34 -07:00
Yuhong Sun
d68d281e1c Slight copy update (#2322) 2024-09-03 20:14:03 -07:00
hagen-danswer
ebce3ff6ba added wait for sync after creating document set in tests (#2319) 2024-09-04 00:34:40 +00:00
pablodanswer
f96bd12ab8 prevent accidental submission (#2318) 2024-09-03 16:44:54 -07:00
pablodanswer
32359d2dff Add user dropdown seed-able list (#2308)
* add user dropdown seedable list

* minor cleanup

* fix build issue

* minor type update

* remove log

* quick update to divider logic (squash)

* tiny icon updates
2024-09-03 19:24:50 +00:00
Chris Weaver
5da6d792de Add ingestion as a "Source" for the FE + improve typing (#2312) 2024-09-03 12:34:31 -07:00
pablodanswer
fb95398e5b Cleaner stream handling in Answer class (#2314)
* add cleaner stream

* add cleaner stream handling
2024-09-03 18:36:01 +00:00
rkuo-danswer
af66650ee3 fail safely if lookup for document fails (#2309) 2024-09-03 10:01:17 -07:00
pablodanswer
5b1f3c8d4e Formatting nits (#2311)
* stream in all cases

* update code block

* code formatting nits

* proper ports

* proper ports

* remove unnecessary lines
2024-09-03 16:05:02 +00:00
hagen-danswer
a3b1b1db38 fixed doc set table (#2306) 2024-09-03 15:36:07 +00:00
Weves
7520fae068 Add back test 2024-09-02 18:04:55 -07:00
Weves
39c946536c Fix deletion due to foreign key issue 2024-09-02 17:56:43 -07:00
Yuhong Sun
90528ba195 k 2024-09-02 17:33:33 -07:00
pablodanswer
6afcaafe54 Continue Generating (#2286)
* add stop reason

* add initial propagation

* add continue generating full functionality

* proper continue across chat session

* add new look

* propagate proper types

* fix typing

* cleaner continue generating functionality

* update types

* remove unused imports

* proper infodump

* temp

* add standardized stream handling

* validateing chosen tool args

* properly handle tools

* proper ports

* remove logs + build

* minor typing fix

* fix more minor typing issues

* add stashed reversion for tool call chunks

* ignore model dump types

* remove stop stream

* fix typing
2024-09-02 22:49:56 +00:00
Yuhong Sun
812ca69949 Vespa Degraded Handling (#2304) 2024-09-02 15:53:37 -07:00
rkuo-danswer
abe01144ca Update CONTRIBUTING.md (#2298) 2024-09-02 15:30:18 -07:00
Yuhong Sun
d988a3e736 Productboard Minor Fix (#2303) 2024-09-02 14:46:35 -07:00
pablodanswer
2b14afe878 Add proper typing such that tests pass mypy (#2301)
* add proper typing such that tests pass mypy

* nit (squash)

* minor update
2024-09-02 21:03:53 +00:00
Chris Weaver
033ec0b6b1 Remove unused env variables (#2299) 2024-09-02 20:29:14 +00:00
pablodanswer
14a9fecc64 update code block (#2297) 2024-09-02 13:33:18 -07:00
Weves
0027f161d7 Fix revisions 2024-09-02 11:13:55 -07:00
Yuhong Sun
32e551b69c Vespa Log No Response (#2295) 2024-09-02 09:14:28 -07:00
pablodanswer
299cb5035c Add litellm proxy embeddings (#2291)
* add litellm proxy

* formatting

* move `api_url` to cloud provider + nits

* remove log

* typing

* quick tuyping fix

* update LiteLLM selection logic

* remove logs + validate functionality

* rename proxy var

* update path casing

* remove pricing for custom models

* functional values
2024-09-02 09:08:35 -07:00
pablodanswer
910821c723 Ordered indexing status (#2292) 2024-09-02 08:39:18 -07:00
hagen-danswer
aa84846298 Connector deletion fix (#2293)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-09-01 23:32:20 -07:00
pablodanswer
c122be2f6a More explicit Confluence Connector (#2289) 2024-09-01 20:35:29 -07:00
Weves
f871b4c6eb Update default anthropic / bedrock models 2024-09-01 20:35:00 -07:00
hagen-danswer
a96cea2ce0 logging improvements 2024-09-01 16:21:35 -07:00
hagen-danswer
8d443ada5b Integration tests (#2256)
* initial commit

* almost done

* finished 3 tests

* minor refactor

* built out initial permisison tests

* reworked test_deletion

* removed logging

* all original tests have been converted

* renamed user_groups to user_group

* mypy

* added test for doc set permissions

* unified naming for manager methods

* Refactored models and added new deletion test

* minor additions

* better logging+fixed input variables

* commented out failed tests

* Added readme

* readme update

* Added auth to IT

set auth_type to basic and require_email_verification to false

* Update run-it.yml

* used verify and added to readme

* added api key manager
2024-09-01 22:21:00 +00:00
pablodanswer
634de83d72 Very minor update to divider logic (#2287) 2024-08-31 14:40:15 -07:00
Yuhong Sun
580848cf8c mypy (#2283) 2024-08-30 18:02:18 -07:00
Yuhong Sun
f01027cfb7 Catch LLM Eval Failures (#2272) 2024-08-30 17:42:58 -07:00
pablodanswer
76db4b765a Detect GPU on startup for default multi-pass indexing value (#2242) 2024-08-30 17:38:31 -07:00
482 changed files with 19979 additions and 7818 deletions

View File

@@ -0,0 +1,76 @@
name: 'Build and Push Docker Image with Retry'
description: 'Attempts to build and push a Docker image, with a retry on failure'
inputs:
context:
description: 'Build context'
required: true
file:
description: 'Dockerfile location'
required: true
platforms:
description: 'Target platforms'
required: true
pull:
description: 'Always attempt to pull a newer version of the image'
required: false
default: 'true'
push:
description: 'Push the image to registry'
required: false
default: 'true'
load:
description: 'Load the image into Docker daemon'
required: false
default: 'true'
tags:
description: 'Image tags'
required: true
cache-from:
description: 'Cache sources'
required: false
cache-to:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before retry in seconds'
required: false
default: '5'
runs:
using: "composite"
steps:
- name: Build and push Docker image (First Attempt)
id: buildx1
uses: docker/build-push-action@v5
continue-on-error: true
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait to retry
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Retry Attempt)
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v5
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}

View File

@@ -27,6 +27,11 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:

View File

@@ -0,0 +1,67 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
runs-on: Amd64
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -24,9 +24,9 @@ jobs:
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@@ -10,6 +10,9 @@ on:
env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -36,8 +39,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -11,7 +11,8 @@ jobs:
env:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -28,8 +29,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -13,8 +13,7 @@ env:
jobs:
integration-tests:
runs-on:
group: 'arm64-image-builders'
runs-on: Amd64
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -28,30 +27,20 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build Web Docker image
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
tags: danswer/danswer-web-server:it
cache-from: type=registry,ref=danswer/danswer-web-server:it
cache-to: |
type=registry,ref=danswer/danswer-web-server:it,mode=max
type=inline
# NOTE: we don't need to build the Web Docker image since it's not used
# during the IT for now. We have a separate action to verify it builds
# succesfully
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
- name: Build Backend Docker image
uses: docker/build-push-action@v5
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
platforms: linux/amd64
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
@@ -59,14 +48,11 @@ jobs:
type=inline
- name: Build Model Server Docker image
uses: docker/build-push-action@v5
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
pull: true
push: true
load: true
platforms: linux/amd64
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
@@ -74,14 +60,11 @@ jobs:
type=inline
- name: Build integration test Docker image
uses: docker/build-push-action@v5
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
platforms: linux/amd64
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |
@@ -92,14 +75,19 @@ jobs:
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=it \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -137,6 +125,7 @@ jobs:
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
danswer/integration-test-runner:it

2
.gitignore vendored
View File

@@ -4,6 +4,6 @@
.mypy_cache
.idea
/deployment/data/nginx/app.conf
.vscode/launch.json
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml

View File

@@ -1,5 +1,5 @@
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
# This will help with development iteration speed and reduce repeat tasks for dev
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=True
DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# Python stuff
PYTHONPATH=./backend
PYTHONPATH=../backend
PYTHONUNBUFFERED=1
@@ -49,4 +49,3 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@@ -1,15 +1,23 @@
/*
Copy this file into '.vscode/launch.json' or merge its
contents into your existing configurations.
*/
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
],
"configurations": [
{
"name": "Web Server",
@@ -17,7 +25,7 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -25,11 +33,12 @@
},
{
"name": "Model Server",
"type": "python",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -39,16 +48,16 @@
"--reload",
"--port",
"9000"
],
"consoleTitle": "Model Server"
]
},
{
"name": "API Server",
"type": "python",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -59,32 +68,32 @@
"--reload",
"--port",
"8080"
],
"consoleTitle": "API Server"
]
},
{
"name": "Indexing",
"type": "python",
"consoleName": "Indexing",
"type": "debugpy",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"consoleTitle": "Indexing"
}
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"type": "python",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -93,18 +102,18 @@
},
"args": [
"--no-indexing"
],
"consoleTitle": "Background Jobs"
]
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"type": "python",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -113,11 +122,12 @@
},
{
"name": "Pytest",
"type": "python",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -128,18 +138,16 @@
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
],
"compounds": [
},
{
"name": "Run Danswer",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
]
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true
}
]
}

View File

@@ -48,23 +48,26 @@ We would love to see you there!
## Get Started 🚀
Danswer being a fully functional app, relies on some external pieces of software, specifically:
Danswer being a fully functional app, relies on some external software, specifically:
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
development purposes but also feel free to just use the containers and update with local changes by providing the
`--build` flag.
> **Note:**
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
### Local Set Up
It is recommended to use Python version 3.11
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
If using a lower version, modifications will have to be made to the code.
If using a higher version, the version of Tensorflow we use may not be available for your platform.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
#### Installing Requirements
#### Backend: Python requirements
Currently, we use pip and recommend creating a virtual environment.
For convenience here's a command for it:
@@ -73,8 +76,9 @@ python -m venv .venv
source .venv/bin/activate
```
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
directory
> **Note:**
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
_For Windows, activate the virtual environment using Command Prompt:_
```bash
@@ -89,34 +93,38 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/ee.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
```bash
playwright install
```
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `danswer/web` run:
```bash
npm i
```
Install Playwright (required by the Web Connector)
#### Docker containers for external software
You will need Docker installed to run these containers.
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
This will update the path to include playwright
Then install Playwright by running:
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
```bash
playwright install
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
#### Dependent Docker Containers
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
```
(index refers to Vespa and relational_db refers to Postgres)
#### Running Danswer
#### Running Danswer locally
To start the frontend, navigate to `danswer/web` and run:
```bash
npm run dev
@@ -127,11 +135,10 @@ Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
```
The first time running Danswer, you will need to run the DB migrations for Postgres.
@@ -154,6 +161,7 @@ To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
@@ -162,20 +170,58 @@ powershell -Command "
"
```
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
#### Wrapping up
You should now have 4 servers running:
- Web server
- Backend API
- Model server
- Background jobs
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
You've successfully set up a local Danswer instance! 🏁
#### Running the Danswer application in a container
You can run the full Danswer application stack from pre-built images including all external software dependencies.
Navigate to `danswer/deployment/docker_compose` and run:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
```
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
```
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `danswer/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we would like to keep it that way!
Danswer is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
@@ -186,6 +232,7 @@ Please double check that prettier passes before creating a pull request.
### Release Process
Danswer follows the semver versioning standard.
Danswer loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).

31
CONTRIBUTING_MACOS.md Normal file
View File

@@ -0,0 +1,31 @@
## Some additional notes for Mac Users
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
### Setting up Python
Ensure [Homebrew](https://brew.sh/) is already set up.
Then install python 3.11.
```bash
brew install python@3.11
```
Add python 3.11 to your path: add the following line to ~/.zshrc
```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```
> **Note:**
> You will need to open a new terminal for the path change above to take effect.
### Setting up Docker
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
ensure it is running before continuing with the docker commands.
### Formatting and Linting
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

View File

@@ -9,7 +9,8 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -40,6 +41,8 @@ RUN apt-get update && \
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
@@ -75,8 +78,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Set up application files
WORKDIR /app

View File

@@ -8,11 +8,17 @@ visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y

View File

@@ -16,7 +16,9 @@ config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# add your model's MetaData object here

View File

@@ -0,0 +1,27 @@
"""add ccpair deletion failure message
Revision ID: 0ebb1d516877
Revises: 52a219fb5233
Create Date: 2024-09-10 15:03:48.233926
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0ebb1d516877"
down_revision = "52a219fb5233"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("deletion_failure_message", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "deletion_failure_message")

View File

@@ -0,0 +1,102 @@
"""add_user_delete_cascades
Revision ID: 1b8206b29c5d
Revises: 35e6853a51d5
Create Date: 2024-09-18 11:48:59.418726
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1b8206b29c5d"
down_revision = "35e6853a51d5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey",
"credential",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey",
"chat_session",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey",
"chat_folder",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key(
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
)
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey",
"notification",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey",
"inputprompt",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
)

View File

@@ -30,7 +30,7 @@ def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
),
)
op.add_column(

View File

@@ -0,0 +1,64 @@
"""server default chosen assistants
Revision ID: 35e6853a51d5
Revises: c99d76fcd298
Create Date: 2024-09-13 13:20:32.885317
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35e6853a51d5"
down_revision = "c99d76fcd298"
branch_labels = None
depends_on = None
DEFAULT_ASSISTANTS = [-2, -1, 0]
def upgrade() -> None:
# Step 1: Update any NULL values to the default value
# This upgrades existing users without ordered assistant
# to have default assistants set to visible assistants which are
# accessible by them.
op.execute(
"""
UPDATE "user" u
SET chosen_assistants = (
SELECT jsonb_agg(
p.id ORDER BY
COALESCE(p.display_priority, 2147483647) ASC,
p.id ASC
)
FROM persona p
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
WHERE p.is_visible = true
AND (p.is_public = true OR pu.user_id IS NOT NULL)
)
WHERE chosen_assistants IS NULL
OR chosen_assistants = 'null'
OR jsonb_typeof(chosen_assistants) = 'null'
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
"""
)
# Step 2: Alter the column to make it non-nullable
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
)
def downgrade() -> None:
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
server_default=None,
)

View File

@@ -0,0 +1,66 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func
# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last modified represents the last time anything needing syncing to vespa changed
# including row metadata and the document itself. This obviously does not include
# the last_synced column.
op.add_column(
"document",
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
server_default=func.now(),
),
)
# last synced represents the last time this document was synced to Vespa
op.add_column(
"document",
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
)
# Set last_synced to the same value as last_modified for existing rows
op.execute(
"""
UPDATE document
SET last_synced = last_modified
"""
)
op.create_index(
op.f("ix_document_last_modified"),
"document",
["last_modified"],
unique=False,
)
op.create_index(
op.f("ix_document_last_synced"),
"document",
["last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
op.drop_column("document", "last_synced")
op.drop_column("document", "last_modified")

View File

@@ -0,0 +1,79 @@
"""assistant_rework
Revision ID: 55546a7967ee
Revises: 61ff3651add4
Create Date: 2024-09-18 17:00:23.755399
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55546a7967ee"
down_revision = "61ff3651add4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Reworking persona and user tables for new assistant features
# keep track of user's chosen assistants separate from their `ordering`
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET builtin_persona = default_persona")
op.alter_column("persona", "builtin_persona", nullable=False)
op.drop_index("_default_persona_name_idx", table_name="persona")
op.create_index(
"_builtin_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("builtin_persona = true"),
)
op.add_column(
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
)
op.add_column(
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
)
op.execute(
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
)
op.alter_column(
"user",
"visible_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.alter_column(
"user",
"hidden_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.drop_column("persona", "default_persona")
op.add_column(
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
)
def downgrade() -> None:
# Reverting changes made in upgrade
op.drop_column("user", "hidden_assistants")
op.drop_column("user", "visible_assistants")
op.drop_index("_builtin_persona_name_idx", table_name="persona")
op.drop_column("persona", "is_default_persona")
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET default_persona = builtin_persona")
op.alter_column("persona", "default_persona", nullable=False)
op.drop_column("persona", "builtin_persona")
op.create_index(
"_default_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("default_persona = true"),
)

View File

@@ -0,0 +1,35 @@
"""match_any_keywords flag for standard answers
Revision ID: 5c7fdadae813
Revises: efb35676026c
Create Date: 2024-09-13 18:52:59.256478
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c7fdadae813"
down_revision = "efb35676026c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_any_keywords",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_any_keywords")
# ### end Alembic commands ###

View File

@@ -0,0 +1,162 @@
"""Add Permission Syncing
Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Admin user who set up connectors will lose access to the docs temporarily
# only way currently to give back access is to rerun from beginning
op.add_column(
"connector_credential_pair",
sa.Column(
"access_type",
sa.String(),
nullable=True,
),
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
)
op.alter_column("connector_credential_pair", "access_type", nullable=False)
op.add_column(
"connector_credential_pair",
sa.Column(
"auto_sync_options",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
)
op.drop_column("connector_credential_pair", "is_public")
op.add_column(
"document",
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
)
op.add_column(
"document",
sa.Column(
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
),
)
op.add_column(
"document",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
op.create_table(
"user__external_user_group_id",
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("user_id"),
)
op.drop_column("external_permission", "user_id")
op.drop_column("email_to_external_user_cache", "user_id")
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
)
op.execute(
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
)
op.alter_column("connector_credential_pair", "is_public", nullable=False)
op.drop_column("connector_credential_pair", "auto_sync_options")
op.drop_column("connector_credential_pair", "access_type")
op.drop_column("connector_credential_pair", "last_time_perm_sync")
op.drop_column("document", "external_user_emails")
op.drop_column("document", "external_user_group_ids")
op.drop_column("document", "is_public")
op.drop_table("user__external_user_group_id")
# Drop the enum type at the end of the downgrade
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -0,0 +1,27 @@
"""persona_start_date
Revision ID: 797089dfb4d2
Revises: 55546a7967ee
Create Date: 2024-09-11 14:51:49.785835
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "797089dfb4d2"
down_revision = "55546a7967ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "search_start_date")

View File

@@ -0,0 +1,158 @@
"""migration confluence to be explicit
Revision ID: a3795dce87be
Revises: 1f60f60c3401
Create Date: 2024-09-01 13:52:12.006740
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import table, column
revision = "a3795dce87be"
down_revision = "1f60f60c3401"
branch_labels: None = None
depends_on: None = None
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
from urllib.parse import urlparse
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(
wiki_url: str,
) -> tuple[str, str, str]:
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
return wiki_base, space, page_id, is_confluence_cloud
def reconstruct_confluence_url(
wiki_base: str, space: str, page_id: str, is_cloud: bool
) -> str:
if is_cloud:
url = f"{wiki_base}/spaces/{space}"
if page_id:
url += f"/pages/{page_id}"
else:
url = f"{wiki_base}/display/{space}"
if page_id:
url += f"/pages/{page_id}"
return url
def upgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
# Fetch all Confluence connectors
connection = op.get_bind()
confluence_connectors = connection.execute(
sa.select(connector).where(
sa.and_(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
).fetchall()
for row in confluence_connectors:
config = row.connector_specific_config
wiki_page_url = config["wiki_page_url"]
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
new_config = {
"wiki_base": wiki_base,
"space": space,
"page_id": page_id,
"is_cloud": is_cloud,
}
for key, value in config.items():
if key not in ["wiki_page_url"]:
new_config[key] = value
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)
def downgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
confluence_connectors = (
op.get_bind()
.execute(
sa.select(connector).where(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
.fetchall()
)
for row in confluence_connectors:
config = row.connector_specific_config
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
wiki_page_url = reconstruct_confluence_url(
config["wiki_base"],
config["space"],
config.get("page_id", ""),
config["is_cloud"],
)
new_config = {"wiki_page_url": wiki_page_url}
new_config.update(
{
k: v
for k, v in config.items()
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
}
)
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)

View File

@@ -0,0 +1,26 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")

View File

@@ -0,0 +1,26 @@
"""Add base_url to CloudEmbeddingProvider
Revision ID: bceb1e139447
Revises: a3795dce87be
Create Date: 2024-08-28 17:00:52.554580
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "a3795dce87be"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "api_url")

View File

@@ -0,0 +1,43 @@
"""non nullable default persona
Revision ID: bd2921608c3a
Revises: 797089dfb4d2
Create Date: 2024-09-20 10:28:37.992042
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd2921608c3a"
down_revision = "797089dfb4d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set existing NULL values to False
op.execute(
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
)
# Alter the column to be not nullable with a default value of False
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
)
def downgrade() -> None:
# Revert the changes
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=True,
server_default=None,
)

View File

@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@@ -0,0 +1,32 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 0ebb1d516877
Create Date: 2024-09-11 13:55:46.101149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "efb35676026c"
down_revision = "0ebb1d516877"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_regex")
# ### end Alembic commands ###

View File

@@ -0,0 +1,26 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: bd2921608c3a
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "bd2921608c3a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("tool", "custom_headers")

View File

@@ -0,0 +1,26 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_column("user", "has_web_login")

View File

@@ -1,26 +1,81 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_email
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
info = get_access_info_for_document(
db_session=db_session,
document_id=document_id,
)
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=info[2] if info else False,
)
def get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
)
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
document_access_info = get_access_info_for_documents(
db_session=db_session,
document_ids=document_ids,
)
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=set(),
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.
for doc_id in document_ids:
if doc_id not in doc_access:
doc_access[doc_id] = get_null_document_access()
return doc_access
def get_access_for_documents(
document_ids: list[str],
@@ -42,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
matches one entry in the returned set.
"""
if user:
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@@ -1,30 +1,72 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
class ExternalAccess:
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Danswer
is_public: bool
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
cls,
user_emails: list[str | None],
user_groups: list[str],
external_user_emails: list[str],
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@@ -1,10 +1,24 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
from danswer.configs.constants import DocumentSource
def prefix_user_email(user_email: str) -> str:
"""Prefixes a user email to eliminate collision with group names.
This applies to both a Danswer user and an External user, this is to make the query time
more efficient"""
return f"user_email:{user_email}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
"""Prefixes a user group name to eliminate collision with user emails.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def prefix_external_group(ext_group_name: str) -> str:
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -16,7 +16,9 @@ from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
@@ -33,6 +35,7 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
@@ -67,23 +70,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def validate_curator_request(groups: list | None, is_public: bool) -> None:
if is_public:
detail = "Curators cannot create public objects"
logger.error(detail)
raise HTTPException(
status_code=401,
detail=detail,
)
if not groups:
detail = "Curators must specify 1+ groups"
logger.error(detail)
raise HTTPException(
status_code=401,
detail=detail,
)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -201,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
@@ -210,7 +196,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
@@ -251,6 +257,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True
return user
async def on_after_register(
@@ -279,6 +297,32 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
send_user_verification_email(user.email, token)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
try:
user = await self.get_by_email(credentials.username)
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})
return user
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -381,6 +425,7 @@ async def optional_user(
async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
) -> User | None:
if optional:
return None
@@ -397,7 +442,11 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
@@ -406,6 +455,12 @@ async def double_check_user(
return user
async def current_user_with_expired_token(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user, include_expired=True)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,361 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id, current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class differs from the default in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
used to implement task prioritization.
This operation is not atomic."""
total_length = 0
for i in range(len(DanswerCeleryPriority)):
queue_name = queue
if i > 0:
queue_name += CELERY_SEPARATOR
queue_name += str(i)
length = r.llen(queue_name)
total_length += cast(int, length)
return total_length

View File

@@ -3,9 +3,8 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -16,30 +15,44 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import TaskStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import RedisPool
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
redis_pool = RedisPool()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
rcd = RedisConnectorDeletion(cc_pair.id)
r = redis_pool.get_client()
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
@@ -56,46 +69,6 @@ def get_deletion_attempt_snapshot(
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = _get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now.")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:

View File

@@ -0,0 +1,76 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
REDIS_SCHEME = "redis"
# SSL-specific query parameters for Redis URL
SSL_QUERY_PARAMS = ""
if REDIS_SSL:
REDIS_SCHEME = "rediss"
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
if REDIS_SSL_CA_CERTS:
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
}
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
# Option 1: Reduces generator task result sizes by roughly 20%
# task_compression = "bzip2"
# task_serializer = "pickle"
# result_compression = "bzip2"
# result_serializer = "pickle"
# accept_content=["pickle"]
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
# can be large. small tasks change very little
# def pickle_bz2_encoder(data):
# return bz2.compress(pickle.dumps(data))
# def pickle_bz2_decoder(data):
# return pickle.loads(bz2.decompress(data))
# from kombu import serialization # To register custom serialization with Celery/Kombu
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]

View File

@@ -13,28 +13,16 @@ connector / credential pair from the access list
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import get_document_connector_counts
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
@@ -57,13 +45,15 @@ def delete_connector_credential_pair_batch(
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_cnts = get_document_connector_cnts(
document_connector_counts = get_document_connector_counts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
document_id
for document_id, cnt in document_connector_counts
if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
@@ -76,7 +66,7 @@ def delete_connector_credential_pair_batch(
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
document_id for document_id, cnt in document_connector_counts if cnt > 1
]
# maps document id to list of document set names
@@ -109,7 +99,7 @@ def delete_connector_credential_pair_batch(
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
@@ -118,79 +108,3 @@ def delete_connector_credential_pair_batch(
),
)
db_session.commit()
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
num_docs_deleted = 0
while True:
documents = get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
num_docs_deleted += len(documents)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.info("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.notice(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted

View File

@@ -56,11 +56,11 @@ def _get_connector_runner(
try:
runnable_connector = instantiate_connector(
attempt.connector_credential_pair.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
db_session=db_session,
source=attempt.connector_credential_pair.connector.source,
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -384,17 +384,22 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
def run_indexing_entrypoint(
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it

View File

@@ -14,14 +14,6 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
@@ -93,9 +85,16 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# mark the task as started
# register_task must come before fn = apply_async or else the task
# might run mark_task_start (and crash) before the task row exists
db_task = register_task(task_name, db_session)
task = fn(args, kwargs, *other_args, **other_kwargs)
register_task(task.id, task_name, db_session)
# we update the celery task id for diagnostic purposes
# but it isn't currently used by any code
db_task.task_id = task.id
db_session.commit()
return task

View File

@@ -47,7 +47,6 @@ from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# If the indexing dies, it's most likely due to resource constraints,
@@ -68,7 +67,7 @@ def _should_create_new_indexing(
) -> bool:
connector = cc_pair.connector
# don't kick off indexing for `NOT_APPLICABLE`
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
@@ -212,7 +211,6 @@ def cleanup_indexing_jobs(
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
@@ -313,7 +311,12 @@ def kickoff_indexing_jobs(
indexing_attempt_count = 0
primary_client_full = False
secondary_client_full = False
for attempt, search_settings in new_indexing_attempts:
if primary_client_full and secondary_client_full:
break
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
@@ -338,20 +341,28 @@ def kickoff_indexing_jobs(
)
continue
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if not use_secondary_index:
if not primary_client_full:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
primary_client_full = True
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if not secondary_client_full:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
secondary_client_full = True
if run:
if indexing_attempt_count == 0:

View File

@@ -122,7 +122,7 @@ def load_personas_from_yaml(
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None

View File

@@ -1,5 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -44,8 +45,26 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
llm_selected_doc_indices: list[int]
class FinalUsedContextDocsResponse(BaseModel):
final_context_docs: list[LlmDoc]
class RelevanceAnalysis(BaseModel):
@@ -78,6 +97,16 @@ class CitationInfo(BaseModel):
document_id: str
class AllCitations(BaseModel):
citations: list[CitationInfo]
# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
citation_map: dict[int, int]
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
@@ -123,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_chunks_indices: list[int] | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
@@ -144,6 +173,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)

View File

@@ -7,12 +7,15 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@@ -70,7 +73,9 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
@@ -85,6 +90,8 @@ from danswer.tools.internet_search.internet_search_tool import (
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -100,9 +107,9 @@ from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def translate_citations(
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
) -> MessageSpecificCitations:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
@@ -117,7 +124,7 @@ def translate_citations(
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]
return citation_to_saved_doc_id_map
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
def _handle_search_tool_response_summary(
@@ -239,11 +246,14 @@ ChatPacket = (
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -263,6 +273,7 @@ def stream_chat_message_objects(
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -434,6 +445,7 @@ def stream_chat_message_objects(
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
)
# Generates full documents currently
@@ -597,8 +609,13 @@ def stream_chat_message_objects(
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
),
)
@@ -663,9 +680,11 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -688,9 +707,14 @@ def stream_chat_message_objects(
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -727,10 +751,18 @@ def stream_chat_message_objects(
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
except ValueError as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
yield StreamingError(error=error_msg)
db_session.rollback()
return
except Exception as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
@@ -743,12 +775,13 @@ def stream_chat_message_objects(
# Post-LLM answer processing
try:
db_citations = None
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
db_citations = translate_citations(
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
@@ -765,18 +798,22 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)
logger.debug("Committing messages")

View File

@@ -126,6 +126,7 @@ try:
except ValueError:
INDEX_BATCH_SIZE = 16
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@@ -149,6 +150,27 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
#####
# Connector Configs
#####

View File

@@ -83,8 +83,15 @@ DISABLE_LLM_DOC_RELEVANCE = (
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Set this to "true" to hard delete chats
# This will make chats unviewable by admins after a user deletes them
# As opposed to soft deleting them, which just hides them from non-admin users
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

View File

@@ -57,9 +57,12 @@ KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -96,6 +99,7 @@ class DocumentSource(str, Enum):
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
ASANA = "asana"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
@@ -130,6 +134,12 @@ class AuthType(str, Enum):
SAML = "saml"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
SLACK = "Slack"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
@@ -166,3 +176,25 @@ class FileOrigin(str, Enum):
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class DanswerCeleryQueues:
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
class DanswerRedisLocks:
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()
MEDIUM = auto()
LOW = auto()
LOWEST = auto()

View File

@@ -39,9 +39,13 @@ SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# User's set embedding batch size overrides the default encoding batch sizes
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
@@ -51,33 +55,11 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument
# Additionally Danswer supports GPT4All and custom request library based models
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
# NOTE: the 3 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

View File

@@ -59,6 +59,8 @@ if __name__ == "__main__":
latest_docs = test_connector.poll_source(one_day_ago, current)
```
> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main.
### Additional Required Changes:
#### Backend Changes

View File

@@ -0,0 +1,233 @@
import time
from collections.abc import Iterator
from datetime import datetime
from typing import Dict
import asana # type: ignore
from danswer.utils.logger import setup_logger
logger = setup_logger()
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
def __init__(
self,
id: str,
title: str,
text: str,
link: str,
last_modified: datetime,
project_gid: str,
project_name: str,
) -> None:
self.id = id
self.title = title
self.text = text
self.link = link
self.last_modified = last_modified
self.project_gid = project_gid
self.project_name = project_name
def __str__(self) -> str:
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None # type: ignore
self.workspace_gid = workspace_gid
self.team_gid = team_gid
self.configuration = asana.Configuration()
self.api_client = asana.ApiClient(self.configuration)
self.tasks_api = asana.TasksApi(self.api_client)
self.stories_api = asana.StoriesApi(self.api_client)
self.users_api = asana.UsersApi(self.api_client)
self.project_api = asana.ProjectsApi(self.api_client)
self.workspaces_api = asana.WorkspacesApi(self.api_client)
self.api_error_count = 0
self.configuration.access_token = api_token
self.task_count = 0
def get_tasks(
self, project_gids: list[str] | None, start_date: str
) -> Iterator[AsanaTask]:
"""Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace."""
logger.info("Starting to fetch Asana projects")
projects = self.project_api.get_projects(
opts={
"workspace": self.workspace_gid,
"opt_fields": "gid,name,archived,modified_at",
}
)
start_seconds = int(time.mktime(datetime.now().timetuple()))
projects_list = []
project_count = 0
for project_info in projects:
project_gid = project_info["gid"]
if project_gids is None or project_gid in project_gids:
projects_list.append(project_gid)
else:
logger.debug(
f"Skipping project: {project_gid} - not in accepted project_gids"
)
project_count += 1
if project_count % 100 == 0:
logger.info(f"Processed {project_count} projects")
logger.info(f"Found {len(projects_list)} projects to process")
for project_gid in projects_list:
for task in self._get_tasks_for_project(
project_gid, start_date, start_seconds
):
yield task
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
if self.api_error_count > 0:
logger.warning(
f"Encountered {self.api_error_count} API errors during task fetching"
)
def _get_tasks_for_project(
self, project_gid: str, start_date: str, start_seconds: int
) -> Iterator[AsanaTask]:
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"
)
simple_start_date = start_date.split(".")[0].split("+")[0]
logger.info(
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
)
opts = {
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
"workspace,permalink_url",
"modified_since": start_date,
}
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
for data in tasks_from_api:
self.task_count += 1
if self.task_count % 10 == 0:
end_seconds = time.mktime(datetime.now().timetuple())
runtime_seconds = end_seconds - start_seconds
if runtime_seconds > 0:
logger.info(
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
)
logger.debug(f"Processing Asana task: {data['name']}")
text = self._construct_task_text(data)
try:
text += self._fetch_and_add_comments(data["gid"])
last_modified_date = self.format_date(data["modified_at"])
text += f"Last modified: {last_modified_date}\n"
task = AsanaTask(
id=data["gid"],
title=data["name"],
text=text,
link=data["permalink_url"],
last_modified=datetime.fromisoformat(data["modified_at"]),
project_gid=project_gid,
project_name=project["name"],
)
yield task
except Exception:
logger.error(
f"Error processing task {data['gid']} in project {project_gid}",
exc_info=True,
)
self.api_error_count += 1
def _construct_task_text(self, data: Dict) -> str:
text = f"{data['name']}\n\n"
if data["notes"]:
text += f"{data['notes']}\n\n"
if data["created_by"] and data["created_by"]["gid"]:
creator = self.get_user(data["created_by"]["gid"])["name"]
created_date = self.format_date(data["created_at"])
text += f"Created by: {creator} on {created_date}\n"
if data["due_on"]:
due_date = self.format_date(data["due_on"])
text += f"Due date: {due_date}\n"
if data["completed_at"]:
completed_date = self.format_date(data["completed_at"])
text += f"Completed on: {completed_date}\n"
text += "\n"
return text
def _fetch_and_add_comments(self, task_gid: str) -> str:
text = ""
stories_opts: Dict[str, str] = {}
story_start = time.time()
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
story_count = 0
comment_count = 0
for story in stories:
story_count += 1
if story["resource_subtype"] == "comment_added":
comment = self.stories_api.get_story(
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
)
commenter = self.get_user(comment["created_by"]["gid"])["name"]
text += f"Comment by {commenter}: {comment['text']}\n\n"
comment_count += 1
story_duration = time.time() - story_start
logger.debug(
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
)
return text
def get_user(self, user_gid: str) -> Dict:
if self._user is not None:
return self._user
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
if not self._user:
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
return {"name": "Unknown"}
return self._user
def format_date(self, date_str: str) -> str:
date = datetime.fromisoformat(date_str)
return time.strftime("%Y-%m-%d", date.timetuple())
def get_time(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

View File

@@ -0,0 +1,120 @@
import datetime
from typing import Any
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana import asana_api
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AsanaConnector(LoadConnector, PollConnector):
def __init__(
self,
asana_workspace_id: str,
asana_project_ids: str | None = None,
asana_team_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.api_token = credentials["asana_api_token_secret"]
self.asana_client = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
logger.info("Asana credentials loaded and API client initialized")
return None
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
start_time = datetime.datetime.fromtimestamp(start).isoformat()
logger.info(f"Starting Asana poll from {start_time}")
asana = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
docs_batch: list[Document] = []
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
doc = self._message_to_doc(task)
docs_batch.append(doc)
if len(docs_batch) >= self.batch_size:
logger.info(f"Yielding batch of {len(docs_batch)} documents")
yield docs_batch
docs_batch = []
if docs_batch:
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
yield docs_batch
logger.info("Asana poll completed")
def load_from_state(self) -> GenerateDocumentsOutput:
logger.notice("Starting full index of all Asana tasks")
return self.poll_source(start=0, end=None)
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,
metadata={
"group": task.project_gid,
"project": task.project_name,
},
)
if __name__ == "__main__":
import time
import os
logger.notice("Starting Asana connector test")
connector = AsanaConnector(
os.environ["WORKSPACE_ID"],
os.environ["PROJECT_IDS"],
os.environ["TEAM_ID"],
)
connector.load_credentials(
{
"asana_api_token_secret": os.environ["API_TOKEN"],
}
)
logger.info("Loading all documents from Asana")
all_docs = connector.load_from_state()
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
logger.info("Polling for documents updated in the last 24 hours")
latest_docs = connector.poll_source(one_day_ago, current)
for docs in latest_docs:
for doc in docs:
print(doc.id)
logger.notice("Asana connector test completed")

View File

@@ -7,7 +7,6 @@ from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import urlparse
import bs4
from atlassian import Confluence # type:ignore
@@ -53,79 +52,6 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
)
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
wiki_base is https://danswer.atlassian.net/wiki
space is 1234abcd
page_id is 5678efgh
"""
parsed_url = urlparse(wiki_url)
wiki_base = (
parsed_url.scheme
+ "://"
+ parsed_url.netloc
+ parsed_url.path.split("/spaces")[0]
)
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
wiki_base is https://danswer.ai/confluence
space is 1234abcd
page_id is 5678efgh
"""
# /display/ is always right before the space and at the end of the base print()
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = (
parsed_url.scheme
+ "://"
+ parsed_url.netloc
+ parsed_url.path.split(DISPLAY)[0]
)
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
try:
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
wiki_url
)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
except Exception as e:
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
return wiki_base, space, page_id, is_confluence_cloud
@lru_cache()
def _get_user(user_id: str, confluence_client: Confluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
@@ -372,7 +298,10 @@ class RecursiveIndexer:
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
wiki_base: str,
space: str,
is_cloud: bool,
page_id: str = "",
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -386,15 +315,15 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_recursively = index_recursively
(
self.wiki_base,
self.space,
self.page_id,
self.is_cloud,
) = extract_confluence_keys_from_url(wiki_page_url)
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
self.page_id = page_id
self.is_cloud = is_cloud
self.space_level_scan = False
self.confluence_client: Confluence | None = None
if self.page_id is None or self.page_id == "":
@@ -414,7 +343,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
cloud=self.is_cloud,
)
return None
@@ -866,7 +794,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=os.environ["CONFLUENCE_TEST_SPACE"],
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
index_recursively=True,
)
connector.load_credentials(
{
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],

View File

@@ -23,7 +23,7 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
max_retries = 10
max_retries = 5
starting_delay = 5
backoff = 2
max_delay = 600
@@ -32,17 +32,24 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass
if retry_after:
if retry_after is not None:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)

View File

@@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]:
return jira_base, jira_project
def extract_text_from_content(content: dict) -> str:
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if "content" in content:
for block in content["content"]:
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
@@ -72,18 +77,15 @@ def _get_comment_strs(
comment_strs = []
for comment in jira.fields.comment.comments:
try:
if hasattr(comment, "body"):
body_text = extract_text_from_content(comment.raw["body"])
elif hasattr(comment, "raw"):
body = comment.raw.get("body", "No body content available")
body_text = (
extract_text_from_content(body) if isinstance(body, dict) else body
)
else:
body_text = "No body attribute found"
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
@@ -126,11 +128,14 @@ def fetch_jira_issues_batch(
)
continue
description = (
jira.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
semantic_rep = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"

View File

@@ -4,6 +4,7 @@ from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
from danswer.connectors.bookstack.connector import BookstackConnector
@@ -91,6 +92,7 @@ def identify_connector_class(
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.ASANA: AsanaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
@@ -124,11 +126,11 @@ def identify_connector_class(
def instantiate_connector(
db_session: Session,
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)

View File

@@ -6,7 +6,6 @@ from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_service_account,
)
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -62,6 +55,8 @@ class GDriveMimeType(str, Enum):
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]
@@ -316,19 +311,22 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = "text/plain"
if mime_type == GDriveMimeType.SPREADSHEET.value:
export_mime_type = "text/csv"
elif mime_type == GDriveMimeType.PPT.value:
export_mime_type = "text/plain"
response = (
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return response.decode("utf-8")
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
elif mime_type == GDriveMimeType.WORD_DOC.value:
response = service.files().get_media(fileId=file["id"]).execute()
return docx_to_text(file=io.BytesIO(response))
@@ -402,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
@@ -504,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:

View File

@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
@@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import SCOPES
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str,
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
@@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
return None
def get_google_drive_creds_for_service_account(
service_account_key_json_str: str,
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
@@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)

View File

@@ -1,7 +1,7 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
]
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

View File

@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
# The default title is semantic_identifier though unless otherwise specified
title: str | None = None
from_ingestion_api: bool = False
# Anything else that may be useful that is specific to this particular connector type that other
# parts of the code may need. If you're unsure, this can be left as None
additional_info: Any = None
def get_title_for_document_index(
self,

View File

@@ -237,6 +237,14 @@ class NotionConnector(LoadConnector, PollConnector):
)
continue
if result_type == "external_object_instance_page":
logger.warning(
f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
f"Notion API does not currently support reading external blocks (as of 24/07/03) "
f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)"
)
continue
cur_result_text_arr = []
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:

View File

@@ -98,6 +98,15 @@ class ProductboardConnector(PollConnector):
owner = self._get_owner_email(feature)
experts = [BasicExpertInfo(email=owner)] if owner else None
metadata: dict[str, str | list[str]] = {}
entity_type = feature.get("type", "feature")
if entity_type:
metadata["entity_type"] = str(entity_type)
status = feature.get("status", {}).get("name")
if status:
metadata["status"] = str(status)
yield Document(
id=feature["id"],
sections=[
@@ -110,10 +119,7 @@ class ProductboardConnector(PollConnector):
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(feature["updatedAt"]),
primary_owners=experts,
metadata={
"entity_type": feature["type"],
"status": feature["status"]["name"],
},
metadata=metadata,
)
def _get_components(self) -> Generator[Document, None, None]:
@@ -174,6 +180,12 @@ class ProductboardConnector(PollConnector):
owner = self._get_owner_email(objective)
experts = [BasicExpertInfo(email=owner)] if owner else None
metadata: dict[str, str | list[str]] = {
"entity_type": "objective",
}
if objective.get("state"):
metadata["state"] = str(objective["state"])
yield Document(
id=objective["id"],
sections=[
@@ -186,10 +198,7 @@ class ProductboardConnector(PollConnector):
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(objective["updatedAt"]),
primary_owners=experts,
metadata={
"entity_type": "release",
"state": objective["state"],
},
metadata=metadata,
)
def _is_updated_at_out_of_time_range(

View File

@@ -25,7 +25,6 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -137,7 +136,7 @@ class SharepointConnector(LoadConnector, PollConnector):
.execute_query()
]
else:
sites = self.graph_client.sites.get().execute_query()
sites = self.graph_client.sites.get_all().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]

View File

@@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -1,6 +1,8 @@
import io
import ipaddress
import socket
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@@ -85,7 +87,8 @@ def check_internet_connection(url: str) -> None:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
status_code = e.response.status_code
# Extract status code from the response, defaulting to -1 if response is None
status_code = e.response.status_code if e.response is not None else -1
error_msg = {
400: "Bad Request",
401: "Unauthorized",
@@ -202,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]:
return urls
def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None:
try:
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
tzinfo=timezone.utc
)
except (ValueError, TypeError):
return None
class WebConnector(LoadConnector):
def __init__(
self,
@@ -287,6 +299,7 @@ class WebConnector(LoadConnector):
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
last_modified = response.headers.get("Last-Modified")
doc_batch.append(
Document(
@@ -295,12 +308,22 @@ class WebConnector(LoadConnector):
source=DocumentSource.WEB,
semantic_identifier=current_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
)
if last_modified
else None,
)
)
continue
page = context.new_page()
page_response = page.goto(current_url)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
@@ -336,6 +359,11 @@ class WebConnector(LoadConnector):
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or current_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
)
if last_modified
else None,
)
)

View File

@@ -3,6 +3,7 @@ from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects import Ticket # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
@@ -59,10 +60,15 @@ class ZendeskClientNotSetUpError(PermissionError):
class ZendeskConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
self.content_type = content_type
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
@@ -122,16 +128,86 @@ class ZendeskConnector(LoadConnector, PollConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _ticket_to_document(self, ticket: Ticket) -> Document:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
owner = None
if ticket.requester and ticket.requester.name and ticket.requester.email:
owner = [
BasicExpertInfo(
display_name=ticket.requester.name, email=ticket.requester.email
)
]
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
metadata: dict[str, str | list[str]] = {}
if ticket.status is not None:
metadata["status"] = ticket.status
if ticket.priority is not None:
metadata["priority"] = ticket.priority
if ticket.tags:
metadata["tags"] = ticket.tags
if ticket.type is not None:
metadata["ticket_type"] = ticket.type
# Fetch comments for the ticket
comments = self.zendesk_client.tickets.comments(ticket=ticket)
# Combine all comments into a single text
comments_text = "\n\n".join(
[
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
for comment in comments
if comment.body
]
)
# Combine ticket description and comments
description = (
ticket.description
if hasattr(ticket, "description") and ticket.description
else ""
)
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
# Extract subdomain from ticket.url
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
# Build the html url for the ticket
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
return Document(
id=f"zendesk_ticket_{ticket.id}",
sections=[Section(link=ticket_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=owner,
metadata=metadata,
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.content_type == "articles":
yield from self._poll_articles(start)
elif self.content_type == "tickets":
yield from self._poll_tickets(start)
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = (
self.zendesk_client.help_center.articles(cursor_pagination=True)
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
if start is None
else self.zendesk_client.help_center.articles.incremental(
else self.zendesk_client.help_center.articles.incremental( # type: ignore
start_time=int(start)
)
)
@@ -155,9 +231,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
if doc_batch:
yield doc_batch
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
while True:
doc_batch = []
for _ in range(self.batch_size):
try:
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.status == "deleted":
continue
doc_batch.append(self._ticket_to_document(ticket))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
except StopIteration:
# No more tickets to process
if doc_batch:
yield doc_batch
return
if doc_batch:
yield doc_batch
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()

View File

@@ -25,7 +25,6 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.icons import source_to_github_img_link
@@ -360,22 +359,6 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_standard_answer_blocks(
answer_message: str,
) -> list[Block]:
generate_button_block = ButtonElement(
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
text="Generate Full Answer",
)
answer_block = SectionBlock(text=answer_message)
return [
answer_block,
ActionsBlock(
elements=[generate_button_block],
),
]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
@@ -87,6 +88,8 @@ def handle_generate_answer_button(
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
user_id = req.payload["user"]["id"]
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
email = expert_info.email if expert_info else None
if not thread_ts:
raise ValueError("Missing thread_ts in the payload")
@@ -125,6 +128,7 @@ def handle_generate_answer_button(
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
is_bot_dm=False,

View File

@@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -209,6 +210,9 @@ def handle_message(
logger.error(f"Was not able to react to user message due to: {e}")
with Session(get_sqlalchemy_engine()) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,

View File

@@ -5,6 +5,7 @@ from typing import cast
from typing import Optional
from typing import TypeVar
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
@@ -38,6 +39,7 @@ from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.db.users import get_user_by_email
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
@@ -99,6 +101,12 @@ def handle_regular_answer(
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
user = None
if message_info.is_bot_dm:
if message_info.email:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
@@ -128,7 +136,8 @@ def handle_regular_answer(
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to:
if not message_ts_to_respond_to and not is_bot_msg:
# if the message is not "/danswer" command, then it should have a message ts to respond to
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
)
@@ -145,15 +154,23 @@ def handle_regular_answer(
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
@@ -185,7 +202,7 @@ def handle_regular_answer(
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
@@ -412,7 +429,7 @@ def handle_regular_answer(
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
@@ -463,7 +480,9 @@ def handle_regular_answer(
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /danswer message
# so we shouldn't send_team_member_message
if receiver_ids and message_ts_to_respond_to is not None:
send_team_member_message(
client=client,
channel=channel,

View File

@@ -1,60 +1,16 @@
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_messages_by_sessions
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Respond to the user message if it matches any configured standard answers.
Returns a list of matching StandardAnswers if found, otherwise None.
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
@@ -63,153 +19,38 @@ def handle_standard_answers(
logger: DanswerLoggingAdapter,
client: WebClient,
db_session: Session,
) -> bool:
"""Returns whether one or more Standard Answer message blocks were
emitted by the Slack bot"""
versioned_handle_standard_answers = fetch_versioned_implementation(
"danswer.danswerbot.slack.handlers.handle_standard_answers",
"_handle_standard_answers",
)
return versioned_handle_standard_answers(
message_info=message_info,
receiver_ids=receiver_ids,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
db_session=db_session,
)
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: DanswerLoggingAdapter,
client: WebClient,
db_session: Session,
) -> bool:
"""
Potentially respond to the user message depending on whether the user's message matches
any of the configured standard answers and also whether those answers have already been
provided in the current thread.
Standard Answers are a paid Enterprise Edition feature. This is the fallback
function handling the case where EE features are not enabled.
Returns True if standard answers are found to match the user's message and therefore,
we still need to respond to the users.
Always returns false i.e. since EE features are not enabled, we NEVER create any
Slack message blocks.
"""
# if no channel config, then no standard answers are configured
if not slack_bot_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_bot_config.standard_answer_categories if slack_bot_config else []
)
configured_standard_answers = set(
[
standard_answer
for standard_answer_category in configured_standard_answer_categories
for standard_answer in standard_answer_category.standard_answers
]
)
query_msg = message_info.thread_messages[-1]
if slack_thread_id is None:
used_standard_answer_ids = set([])
else:
chat_sessions = get_chat_sessions_by_slack_thread_id(
slack_thread_id=slack_thread_id,
user_id=None,
db_session=db_session,
)
chat_messages = get_chat_messages_by_sessions(
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
user_id=None,
db_session=db_session,
skip_permission_check=True,
)
used_standard_answer_ids = set(
[
standard_answer.id
for chat_message in chat_messages
for standard_answer in chat_message.standard_answers
]
)
usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids
)
if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers(
query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session,
)
else:
matching_standard_answers = []
if matching_standard_answers:
chat_session = create_chat_session(
db_session=db_session,
description="",
user_id=None,
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
danswerbot_flow=True,
slack_thread_id=slack_thread_id,
one_shot=True,
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
formatted_answers = []
for standard_answer in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = (
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
)
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
message=answer_message,
token_count=0,
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=True,
)
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
)
answer_blocks = build_standard_answer_blocks(
answer_message=answer_message,
)
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_info.msg_to_respond,
unfurl=False,
)
if receiver_ids and slack_thread_id:
send_team_member_message(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
)
return True
except Exception as e:
logger.exception(f"Unable to send standard answer message: {e}")
return False
else:
return False
return False

View File

@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@@ -55,6 +56,7 @@ from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -256,6 +258,11 @@ def build_request_details(
tagged = event.get("type") == "app_mention"
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender = event.get("user") or None
expert_info = expert_info_from_slack_id(
sender, client.web_client, user_cache={}
)
email = expert_info.email if expert_info else None
msg = remove_danswer_bot_tag(msg, client=client.web_client)
@@ -286,7 +293,8 @@ def build_request_details(
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=event.get("user") or None,
sender=sender,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
is_bot_dm=event.get("channel_type") == "im",
@@ -296,6 +304,10 @@ def build_request_details(
channel = req.payload["channel_id"]
msg = req.payload["text"]
sender = req.payload["user_id"]
expert_info = expert_info_from_slack_id(
sender, client.web_client, user_cache={}
)
email = expert_info.email if expert_info else None
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
@@ -305,6 +317,7 @@ def build_request_details(
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
@@ -469,6 +482,8 @@ if __name__ == "__main__":
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None
set_is_ee_based_on_env_variable()
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()

View File

@@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel):
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
email: str | None
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot
is_bot_dm: bool # User is direct messaging to DanswerBot

View File

@@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
"danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
)
return get_default_admin_user_emails_fn()

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -87,29 +86,57 @@ def get_chat_sessions_by_slack_thread_id(
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
def get_valid_messages_from_query_sessions(
chat_session_ids: list[int],
db_session: Session,
) -> dict[int, str]:
subquery = (
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
user_message_subquery = (
select(
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
)
.where(
and_(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER, # Select USER messages
)
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER,
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
subquery,
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
& (ChatMessage.id == subquery.c.min_id),
assistant_message_subquery = (
select(
ChatMessage.chat_session_id,
func.min(ChatMessage.id).label("assistant_msg_id"),
)
.where(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = (
select(ChatMessage.chat_session_id, ChatMessage.message)
.join(
user_message_subquery,
ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id,
)
.join(
assistant_message_subquery,
ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id,
)
.join(
ChatMessage__SearchDoc,
ChatMessage__SearchDoc.chat_message_id
== assistant_message_subquery.c.assistant_msg_id,
)
.where(ChatMessage.id == user_message_subquery.c.user_msg_id)
)
first_messages = db_session.execute(query).all()
return dict([(row.chat_session_id, row.message) for row in first_messages])
logger.info(f"Retrieved {len(first_messages)} first messages with documents")
return {row.chat_session_id: row.message for row in first_messages}
def get_chat_sessions_by_user(
@@ -199,7 +226,7 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
@@ -253,6 +280,13 @@ def delete_chat_session(
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if chat_session.deleted:
raise ValueError("Cannot delete an already deleted chat session")
if hard_delete:
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
@@ -564,6 +598,7 @@ def get_doc_query_identifiers_from_model(
chat_session: ChatSession,
user_id: UUID | None,
db_session: Session,
enforce_chat_session_id_for_search_docs: bool,
) -> list[tuple[str, int]]:
"""Given a list of search_doc_ids"""
search_docs = (
@@ -583,7 +618,8 @@ def get_doc_query_identifiers_from_model(
for doc in search_docs
]
):
raise ValueError("Invalid reference doc, not from this chat session.")
if enforce_chat_session_id_for_search_docs:
raise ValueError("Invalid reference doc, not from this chat session.")
except IndexError:
# This happens when the doc has no chat_messages associated with it.
# which happens as an edge case where the chat message failed to save

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
logger = setup_logger()
@@ -74,7 +79,7 @@ def _add_user_filters(
.correlate(ConnectorCredentialPair)
)
else:
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
@@ -94,8 +99,19 @@ def get_connector_credential_pairs(
) # noqa
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
results = db_session.scalars(stmt)
return list(results.all())
return list(db_session.scalars(stmt).all())
def add_deletion_failure_message(
db_session: Session,
cc_pair_id: int,
failure_message: str,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return
cc_pair.deletion_failure_message = failure_message
db_session.commit()
def get_cc_pair_groups_for_ids(
@@ -297,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
access_type=AccessType.PUBLIC,
name="DefaultCCPair",
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=True,
)
db_session.add(association)
db_session.commit()
@@ -324,8 +340,9 @@ def add_credential_to_connector(
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
is_public: bool,
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
@@ -333,10 +350,21 @@ def add_credential_to_connector(
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
)
if credential is None:
error_msg = (
f"Credential {credential_id} does not exist or does not belong to user"
)
logger.error(error_msg)
raise HTTPException(
status_code=401,
detail="Credential does not exist or does not belong to user",
detail=error_msg,
)
existing_association = (
@@ -350,7 +378,7 @@ def add_credential_to_connector(
if existing_association is not None:
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
message=f"Connector {connector_id} already has Credential {credential_id}",
data=connector_id,
)
@@ -359,12 +387,13 @@ def add_credential_to_connector(
credential_id=credential_id,
name=cc_pair_name,
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=is_public,
access_type=access_type,
auto_sync_options=auto_sync_options,
)
db_session.add(association)
db_session.flush() # make sure the association has an id
if groups:
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
@@ -374,8 +403,8 @@ def add_credential_to_connector(
db_session.commit()
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
success=True,
message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
data=association.id,
)
@@ -407,6 +436,10 @@ def remove_credential_from_connector(
)
if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)
db_session.delete(association)
db_session.commit()
return StatusResponse(

View File

@@ -3,26 +3,30 @@ import time
from collections.abc import Generator
from collections.abc import Sequence
from datetime import datetime
from uuid import UUID
from datetime import timezone
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.tag import delete_document_tags_for_documents__no_commit
from danswer.db.utils import model_to_dict
from danswer.document_index.interfaces import DocumentMetadata
@@ -38,6 +42,68 @@ def check_docs_exist(db_session: Session) -> bool:
return result.scalar() or False
def count_documents_by_needs_sync(session: Session) -> int:
"""Get the count of all documents where:
1. last_modified is newer than last_synced
2. last_synced is null (meaning we've never synced)
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count())
.select_from(DbDocument)
.filter(
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
)
)
.scalar()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def construct_document_select_for_connector_credential_pair(
connector_id: int, credential_id: int | None = None
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
return stmt
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -62,7 +128,18 @@ def get_documents_by_ids(
return list(documents)
def get_document_connector_cnts(
def get_document_connector_count(
db_session: Session,
document_id: str,
) -> int:
results = get_document_connector_counts(db_session, [document_id])
if not results or len(results) == 0:
return 0
return results[0][1]
def get_document_connector_counts(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, int]]:
@@ -77,7 +154,7 @@ def get_document_connector_cnts(
return db_session.execute(stmt).all() # type: ignore
def get_document_cnts_for_cc_pairs(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
stmt = (
@@ -108,22 +185,50 @@ def get_document_cnts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
def get_acccess_info_for_documents(
def get_access_info_for_document(
db_session: Session,
document_id: str,
) -> tuple[str, list[str | None], bool] | None:
"""Gets access info for a single document by calling the get_access_info_for_documents function
and passing a list with a single document ID.
Args:
db_session (Session): The database session to use.
document_id (str): The document ID to fetch access info for.
Returns:
Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
and a boolean indicating if the document is globally public, or None if no results are found.
"""
results = get_access_info_for_documents(db_session, [document_id])
if not results:
return None
return results[0]
def get_access_info_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[UUID | None], bool]]:
) -> Sequence[tuple[str, list[str | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for cc pairs that the document is associated with + whether any
of the associated cc pairs are intending to make the document globally public.
Returns the list where each element contains:
- Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
- List of emails of Danswer users with direct access to the doc (includes a "None" element if
the connector was set up by an admin when auth was off
- bool for whether the document is public (the document later can also be marked public by
automatic permission sync step)
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
"public_doc"
),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
stmt = (
select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.join(
stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
)
@@ -136,6 +241,13 @@ def get_acccess_info_for_documents(
== ConnectorCredentialPair.credential_id,
),
)
.outerjoin(
User,
and_(
Credential.user_id == User.id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
)
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
@@ -173,6 +285,7 @@ def upsert_documents(
semantic_id=doc.semantic_identifier,
link=doc.first_link,
doc_updated_at=None, # this is intentional
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
)
@@ -180,9 +293,19 @@ def upsert_documents(
for doc in seen_documents.values()
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
"boost": insert_stmt.excluded.boost,
"hidden": insert_stmt.excluded.hidden,
"semantic_id": insert_stmt.excluded.semantic_id,
"link": insert_stmt.excluded.link,
"primary_owners": insert_stmt.excluded.primary_owners,
"secondary_owners": insert_stmt.excluded.secondary_owners,
},
)
db_session.execute(on_conflict_stmt)
db_session.commit()
@@ -214,7 +337,7 @@ def upsert_document_by_connector_credential_pair(
db_session.commit()
def update_docs_updated_at(
def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,
) -> None:
@@ -226,6 +349,28 @@ def update_docs_updated_at(
for document in documents_to_update:
document.doc_updated_at = ids_to_new_updated_at[document.id]
def update_docs_last_modified__no_commit(
document_ids: list[str],
db_session: Session,
) -> None:
documents_to_update = (
db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
)
now = datetime.now(timezone.utc)
for doc in documents_to_update:
doc.last_modified = now
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc = db_session.scalar(stmt)
if doc is None:
raise ValueError(f"No document with ID: {document_id}")
# update last_synced
doc.last_synced = datetime.now(timezone.utc)
db_session.commit()
@@ -241,11 +386,34 @@ def upsert_documents_complete(
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
"""Deletes a single document by cc pair relationship entry.
Foreign key rows are left in place.
The implicit assumption is that the document itself still has other cc_pair
references and needs to continue existing.
"""
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=[document_id],
connector_credential_pair_identifier=connector_credential_pair_identifier,
)
def delete_documents_by_connector_credential_pair__no_commit(
db_session: Session,
document_ids: list[str],
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
"""This deletes just the document by cc pair entries for a particular cc pair.
Foreign key rows are left in place.
The implicit assumption is that the document itself still has other cc_pair
references and needs to continue existing.
"""
stmt = delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)
@@ -268,8 +436,9 @@ def delete_documents__no_commit(db_session: Session, document_ids: list[str]) ->
def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session
)
@@ -379,3 +548,12 @@ def get_documents_by_cc_pair(
.filter(ConnectorCredentialPair.id == cc_pair_id)
.all()
)
def get_document(
document_id: str,
db_session: Session,
) -> DbDocument | None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
return doc

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
@@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
ids=missing_cc_pair_ids,
)
for cc_pair in cc_pairs:
if not cc_pair.is_public:
if cc_pair.access_type != AccessType.PUBLIC:
raise ValueError(
f"Connector Credential Pair with ID: '{cc_pair.id}'"
" is not owned by the specified groups"
@@ -248,6 +249,10 @@ def update_document_set(
document_set_update_request: DocumentSetUpdateRequest,
user: User | None = None,
) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
"""If successful, this sets document_set_row.is_up_to_date = False.
That will be processed via Celery in check_for_vespa_sync_task
and trigger a long running background sync to Vespa.
"""
if not document_set_update_request.cc_pair_ids:
# It's cc-pairs in actuality but the UI displays this error
raise ValueError("Cannot create a document set with no Connectors")
@@ -519,11 +524,101 @@ def fetch_documents_for_document_set_paginated(
return documents, documents[-1].id if documents else None
def construct_document_select_by_docset(
document_set_id: int,
current_only: bool = True,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
)
.join(
ConnectorCredentialPair,
and_(
ConnectorCredentialPair.connector_id
== DocumentByConnectorCredentialPair.connector_id,
ConnectorCredentialPair.credential_id
== DocumentByConnectorCredentialPair.credential_id,
),
)
.join(
DocumentSet__ConnectorCredentialPair,
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
== ConnectorCredentialPair.id,
)
.join(
DocumentSetDBModel,
DocumentSetDBModel.id
== DocumentSet__ConnectorCredentialPair.document_set_id,
)
.where(DocumentSetDBModel.id == document_set_id)
.order_by(Document.id)
)
if current_only:
stmt = stmt.where(
DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
)
stmt = stmt.distinct()
return stmt
def fetch_document_sets_for_document(
document_id: str,
db_session: Session,
) -> list[str]:
"""
Fetches the document set names for a single document ID.
:param document_id: The ID of the document to fetch sets for.
:param db_session: The SQLAlchemy session to use for the query.
:return: A list of document set names, or None if no result is found.
"""
result = fetch_document_sets_for_documents([document_id], db_session)
if not result:
return []
return result[0][1]
def fetch_document_sets_for_documents(
document_ids: list[str],
db_session: Session,
) -> Sequence[tuple[str, list[str]]]:
"""Gives back a list of (document_id, list[document_set_names]) tuples"""
"""Building subqueries"""
# NOTE: have to build these subqueries first in order to guarantee that we get one
# returned row for each specified document_id. Basically, we want to do the filters first,
# then the outer joins.
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
# as we can assume their document sets are no longer relevant
valid_cc_pairs_subquery = aliased(
ConnectorCredentialPair,
select(ConnectorCredentialPair)
.where(
ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING
) # noqa: E712
.subquery(),
)
valid_document_set__cc_pairs_subquery = aliased(
DocumentSet__ConnectorCredentialPair,
select(DocumentSet__ConnectorCredentialPair)
.where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712
.subquery(),
)
"""End building subqueries"""
stmt = (
select(
Document.id,
@@ -531,39 +626,33 @@ def fetch_document_sets_for_documents(
func.array_remove(func.array_agg(DocumentSetDBModel.name), None), []
).label("document_set_names"),
)
# Here we select document sets by relation:
# Document -> DocumentByConnectorCredentialPair -> ConnectorCredentialPair ->
# DocumentSet__ConnectorCredentialPair -> DocumentSet
.outerjoin(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.outerjoin(
ConnectorCredentialPair,
valid_cc_pairs_subquery,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
== valid_cc_pairs_subquery.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
== valid_cc_pairs_subquery.credential_id,
),
)
.outerjoin(
DocumentSet__ConnectorCredentialPair,
ConnectorCredentialPair.id
== DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
valid_document_set__cc_pairs_subquery,
valid_cc_pairs_subquery.id
== valid_document_set__cc_pairs_subquery.connector_credential_pair_id,
)
.outerjoin(
DocumentSetDBModel,
DocumentSetDBModel.id
== DocumentSet__ConnectorCredentialPair.document_set_id,
== valid_document_set__cc_pairs_subquery.document_set_id,
)
.where(Document.id.in_(document_ids))
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
# as we can assume their document sets are no longer relevant
.where(
ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING,
)
.where(
DocumentSet__ConnectorCredentialPair.is_current == True, # noqa: E712
)
.group_by(Document.id)
)
return db_session.execute(stmt).all() # type: ignore
@@ -616,7 +705,7 @@ def check_document_sets_are_public(
ConnectorCredentialPair.id.in_(
connector_credential_pair_ids # type:ignore
),
ConnectorCredentialPair.is_public.is_(False),
ConnectorCredentialPair.access_type != AccessType.PUBLIC,
)
.limit(1)
.first()

View File

@@ -137,8 +137,8 @@ def get_sqlalchemy_engine() -> Engine:
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
@@ -156,8 +156,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)

View File

@@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
def is_active(self) -> bool:
return self == ConnectorCredentialPairStatus.ACTIVE
class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from fastapi import HTTPException
@@ -14,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.chat import get_chat_message
from danswer.db.enums import AccessType
from danswer.db.models import ChatMessageFeedback
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document as DbDocument
@@ -24,7 +27,6 @@ from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -93,7 +95,7 @@ def _add_user_filters(
.correlate(CCPair)
)
else:
where_clause |= CCPair.is_public == True # noqa: E712
where_clause |= CCPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
@@ -123,12 +125,11 @@ def update_document_boost(
db_session: Session,
document_id: str,
boost: int,
document_index: DocumentIndex,
user: User | None = None,
) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
stmt = _add_user_filters(stmt, user, get_editable=True)
result = db_session.execute(stmt).scalar_one_or_none()
result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
if result is None:
raise HTTPException(
status_code=400, detail="Document is not editable by this user"
@@ -136,13 +137,9 @@ def update_document_boost(
result.boost = boost
update = UpdateRequest(
document_ids=[document_id],
boost=boost,
)
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
result.last_modified = datetime.now(timezone.utc)
db_session.commit()
@@ -163,13 +160,9 @@ def update_document_hidden(
result.hidden = hidden
update = UpdateRequest(
document_ids=[document_id],
hidden=hidden,
)
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
result.last_modified = datetime.now(timezone.utc)
db_session.commit()
@@ -210,11 +203,9 @@ def create_doc_retrieval_feedback(
SearchFeedbackType.REJECT,
SearchFeedbackType.HIDE,
]:
update = UpdateRequest(
document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
)
# Updates are generally batched for efficiency, this case only 1 doc/value is updated
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
db_doc.last_modified = datetime.now(timezone.utc)
db_session.add(retrieval_feedback)
db_session.commit()

View File

@@ -181,6 +181,45 @@ def get_last_attempt(
return db_session.execute(stmt).scalars().first()
def get_latest_index_attempts_by_status(
secondary_index: bool,
db_session: Session,
status: IndexingStatus,
) -> Sequence[IndexAttempt]:
"""
Retrieves the most recent index attempt with the specified status for each connector_credential_pair.
Filters attempts based on the secondary_index flag to get either future or present index attempts.
Returns a sequence of IndexAttempt objects, one for each unique connector_credential_pair.
"""
latest_failed_attempts = (
select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_failed_id"),
)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.where(
SearchSettings.status
== (
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
),
IndexAttempt.status == status,
)
.group_by(IndexAttempt.connector_credential_pair_id)
.subquery()
)
stmt = select(IndexAttempt).join(
latest_failed_attempts,
(
IndexAttempt.connector_credential_pair_id
== latest_failed_attempts.c.connector_credential_pair_id
)
& (IndexAttempt.id == latest_failed_attempts.c.max_failed_id),
)
return db_session.execute(stmt).scalars().all()
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
@@ -211,12 +250,41 @@ def get_latest_index_attempts(
return db_session.execute(stmt).scalars().all()
def get_index_attempts_for_connector(
def count_index_attempts_for_connector(
db_session: Session,
connector_id: int,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
) -> int:
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(ConnectorCredentialPair.connector_id == connector_id)
)
if disinclude_finished:
stmt = stmt.where(
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
)
)
if only_current:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
# Count total items for pagination
count_stmt = stmt.with_only_columns(func.count()).order_by(None)
total_count = db_session.execute(count_stmt).scalar_one()
return total_count
def get_paginated_index_attempts_for_cc_pair_id(
db_session: Session,
connector_id: int,
page: int,
page_size: int,
only_current: bool = True,
disinclude_finished: bool = False,
) -> list[IndexAttempt]:
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
@@ -233,22 +301,30 @@ def get_index_attempts_for_connector(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(IndexAttempt.time_created.desc())
return db_session.execute(stmt).scalars().all()
stmt = stmt.order_by(IndexAttempt.time_started.desc())
# Apply pagination
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_finished_index_attempt_for_cc_pair(
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
db_session: Session,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).distinct()
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
@@ -295,14 +371,21 @@ def get_index_attempts_for_cc_pair(
def delete_index_attempts(
connector_id: int,
credential_id: int,
cc_pair_id: int,
db_session: Session,
) -> None:
# First, delete related entries in IndexAttemptErrors
stmt_errors = delete(IndexAttemptError).where(
IndexAttemptError.index_attempt_id.in_(
select(IndexAttempt.id).where(
IndexAttempt.connector_credential_pair_id == cc_pair_id
)
)
)
db_session.execute(stmt_errors)
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.connector_credential_pair_id == cc_pair_id,
)
db_session.execute(stmt)

View File

@@ -4,8 +4,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -50,6 +53,7 @@ def upsert_cloud_embedding_provider(
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
@@ -58,13 +62,21 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
db_session: Session, llm_provider: LLMProviderUpsertRequest
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProvider:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
@@ -101,6 +113,20 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_doc_sets(
db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
return list(
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
)
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
return list(
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
)
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
@@ -157,12 +183,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
def remove_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> None:
db_session.execute(
delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
)
# Delete the embedding provider
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.provider_type == provider_type
)
)
db_session.commit()
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
@@ -178,7 +211,7 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.commit()
def update_default_provider(db_session: Session, provider_id: int) -> None:
def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)

View File

@@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -61,7 +62,7 @@ from shared_configs.enums import RerankerProvider
class Base(DeclarativeBase):
pass
__abstract__ = True
class EncryptedString(TypeDecorator):
@@ -108,7 +109,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
)
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
@@ -122,7 +123,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=True
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
)
visible_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
@@ -157,6 +164,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
class InputPrompt(Base):
@@ -168,7 +177,9 @@ class InputPrompt(Base):
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
@@ -212,7 +223,9 @@ class Notification(Base):
notif_type: Mapped[NotificationType] = mapped_column(
Enum(NotificationType, native_enum=False)
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
@@ -247,7 +260,7 @@ class Persona__User(Base):
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -258,7 +271,7 @@ class DocumentSet__User(Base):
ForeignKey("document_set.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -373,16 +386,29 @@ class ConnectorCredentialPair(Base):
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
deletion_failure_message: Mapped[str | None] = mapped_column(String, nullable=True)
credential_id: Mapped[int] = mapped_column(
ForeignKey("credential.id"), primary_key=True
)
# controls whether the documents indexed by this CC pair are visible to all
# or if they are only visible to those with that are given explicit access
# (e.g. via owning the credential or being a part of a group that is given access)
is_public: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
access_type: Mapped[AccessType] = mapped_column(
Enum(AccessType, native_enum=False), nullable=False
)
# special info needed for the auto-sync feature. The exact structure depends on the
# source type (defined in the connector's `source` field)
# E.g. for google_drive perm sync:
# {"customer_id": "123567", "company_domain": "@danswer.ai"}
auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
@@ -413,6 +439,7 @@ class ConnectorCredentialPair(Base):
class Document(Base):
__tablename__ = "document"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Danswer)
@@ -426,12 +453,27 @@ class Document(Base):
semantic_id: Mapped[str] = mapped_column(String)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
# connector retries)
# TODO: rename this column because it conflates the time of the source doc
# with the local last modified time of the doc and any associated metadata
# it should just be the server timestamp of the source doc
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# last time any vespa relevant row metadata or the doc changed.
# does not include last_synced
last_modified: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
# last successful sync to vespa
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
# The following are not attached to User because the account/email may not be known
# within Danswer
# Something like the document creator
@@ -441,14 +483,25 @@ class Document(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
# Permission sync columns
# Email addresses are saved at the document level for externally synced permissions
# This is becuase the normal flow of assigning permissions is through the cc_pair
# doesn't apply here
external_user_emails: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# These group ids have been prefixed by the source type
external_user_group_ids: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
tags = relationship(
"Tag",
secondary="document__tag",
secondary=Document__Tag.__table__,
back_populates="documents",
)
@@ -465,7 +518,7 @@ class Tag(Base):
documents = relationship(
"Document",
secondary="document__tag",
secondary=Document__Tag.__table__,
back_populates="tags",
)
@@ -521,7 +574,9 @@ class Credential(Base):
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# if `true`, then all Admins will have access to the credential
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -576,6 +631,8 @@ class SearchSettings(Base):
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
@@ -607,6 +664,10 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_url(self) -> str | None:
return self.cloud_provider.api_url if self.cloud_provider is not None else None
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
@@ -671,7 +732,11 @@ class IndexAttempt(Base):
"SearchSettings", back_populates="index_attempts"
)
error_rows = relationship("IndexAttemptError", back_populates="index_attempt")
error_rows = relationship(
"IndexAttemptError",
back_populates="index_attempt",
cascade="all, delete-orphan",
)
__table_args__ = (
Index(
@@ -806,7 +871,7 @@ class SearchDoc(Base):
chat_messages = relationship(
"ChatMessage",
secondary="chat_message__search_doc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="search_docs",
)
@@ -835,8 +900,12 @@ class ChatSession(Base):
__tablename__ = "chat_session"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -870,7 +939,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -879,7 +947,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
@@ -949,7 +1016,7 @@ class ChatMessage(Base):
)
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
secondary="chat_message__search_doc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
@@ -972,7 +1039,9 @@ class ChatFolder(Base):
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
@@ -1085,6 +1154,7 @@ class CloudEmbeddingProvider(Base):
provider_type: Mapped[EmbeddingProvider] = mapped_column(
Enum(EmbeddingProvider), primary_key=True
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
@@ -1102,7 +1172,9 @@ class DocumentSet(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# Whether changes to the document set have been propagated
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# If `False`, then the document set is not visible to users who are not explicitly
@@ -1146,7 +1218,9 @@ class Prompt(Base):
__tablename__ = "prompt"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
system_prompt: Mapped[str] = mapped_column(Text)
@@ -1181,9 +1255,13 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
@@ -1207,7 +1285,9 @@ class Persona(Base):
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
# Number of chunks to pass to the LLM for generation.
@@ -1236,9 +1316,18 @@ class Persona(Base):
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Default personas are configured via backend during deployment
search_start_date: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
# Built-in personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# Default personas are personas created by admins and are automatically added
# to all users' assistants list.
is_default_persona: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
@@ -1289,10 +1378,10 @@ class Persona(Base):
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
Index(
"_default_persona_name_idx",
"_builtin_persona_name_idx",
"name",
unique=True,
postgresql_where=(default_persona == True), # noqa: E712
postgresql_where=(builtin_persona == True), # noqa: E712
),
)
@@ -1316,53 +1405,6 @@ class ChannelConfig(TypedDict):
follow_up_tags: NotRequired[list[str]]
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
@@ -1388,7 +1430,7 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
@@ -1400,7 +1442,7 @@ class TaskQueueState(Base):
__tablename__ = "task_queue_jobs"
id: Mapped[int] = mapped_column(primary_key=True)
# Celery task id
# Celery task id. currently only for readability/diagnostics
task_id: Mapped[str] = mapped_column(String)
# For any job type, this would be the same
task_name: Mapped[str] = mapped_column(String)
@@ -1450,7 +1492,9 @@ class SamlAccount(Base):
__tablename__ = "saml"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True
)
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
updated_at: Mapped[datetime.datetime] = mapped_column(
@@ -1469,7 +1513,7 @@ class User__UserGroup(Base):
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -1618,94 +1662,70 @@ class TokenRateLimit__UserGroup(Base):
)
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
"""Tables related to Permission Sync"""
class PermissionSyncStatus(str, PyEnum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
class PermissionSyncJobType(str, PyEnum):
USER_LEVEL = "user_level"
GROUP_LEVEL = "group_level"
class PermissionSyncRun(Base):
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
the users or it is sync-ing the groups"""
__tablename__ = "permission_sync_run"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
# Not strictly needed but makes it easy to use without fetching from cc_pair
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
# Currently all sync jobs are handled as a group permission sync or a user permission sync
update_type: Mapped[PermissionSyncJobType] = mapped_column(
Enum(PermissionSyncJobType)
)
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=True
)
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
class ExternalPermission(Base):
class User__ExternalUserGroupId(Base):
"""Maps user info both internal and external to the name of the external group
This maps the user to all of their external groups so that the external group name can be
attached to the ACL list matching during query time. User level permissions can be handled by
directly adding the Danswer user to the doc ACL list"""
__tablename__ = "external_permission"
__tablename__ = "user__external_user_group_id"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
external_permission_group: Mapped[str] = mapped_column(String)
user = relationship("User")
class EmailToExternalUserCache(Base):
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
when the user joins. Used as a cache for when fetching external groups which have their own
user ids, this can easily be mapped back to users already known in Danswer without needing
to call external APIs to get the user emails.
This way when groups are updated in the external tool and we need to update the mapping of
internal users to the groups, we can sync the internal users to the external groups they are
part of using this.
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
part of alpha in some external tool.
"""
__tablename__ = "email_to_external_user_cache"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_user_id: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
user = relationship("User")
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
class UsageReport(Base):
@@ -1721,7 +1741,7 @@ class UsageReport(Base):
# if None, report was auto-generated
requestor_user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), nullable=True
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from functools import lru_cache
from uuid import UUID
@@ -178,6 +179,7 @@ def create_update_persona(
except ValueError as e:
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
@@ -210,6 +212,22 @@ def update_persona_shared_users(
)
def update_persona_public_status(
persona_id: int,
is_public: bool,
db_session: Session,
user: User | None,
) -> None:
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise ValueError("You don't have permission to modify this persona")
persona.is_public = is_public
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
@@ -242,7 +260,7 @@ def get_personas(
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.default_persona.is_(False))
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
@@ -290,7 +308,7 @@ def mark_delete_persona_by_name(
) -> None:
stmt = (
update(Persona)
.where(Persona.name == persona_name, Persona.default_persona == is_default)
.where(Persona.name == persona_name, Persona.builtin_persona == is_default)
.values(deleted=True)
)
@@ -390,7 +408,6 @@ def upsert_persona(
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
icon_color: str | None = None,
icon_shape: int | None = None,
@@ -398,6 +415,9 @@ def upsert_persona(
display_priority: int | None = None,
is_visible: bool = True,
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -438,8 +458,8 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
if not builtin_persona and persona.builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
@@ -454,7 +474,7 @@ def upsert_persona(
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.builtin_persona = builtin_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
@@ -466,6 +486,8 @@ def upsert_persona(
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
@@ -493,7 +515,7 @@ def upsert_persona(
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
default_persona=default_persona,
builtin_persona=builtin_persona,
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
@@ -505,6 +527,8 @@ def upsert_persona(
uploaded_image_id=uploaded_image_id,
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
)
db_session.add(persona)
@@ -534,7 +558,7 @@ def delete_old_default_personas(
Need a more graceful fix later or those need to never have IDs"""
stmt = (
update(Persona)
.where(Persona.default_persona, Persona.id > 0)
.where(Persona.builtin_persona, Persona.id > 0)
.values(deleted=True, name=func.concat(Persona.name, "_old"))
)
@@ -551,6 +575,7 @@ def update_persona_visibility(
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_visible = is_visible
db_session.commit()
@@ -563,13 +588,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return prompts
return list(prompts)
def get_prompt_by_id(
@@ -650,9 +677,7 @@ def get_persona_by_id(
result = db_session.execute(persona_stmt)
persona = result.scalar_one_or_none()
if persona is None:
raise ValueError(
f"Persona with ID {persona_id} does not exist or does not belong to user"
)
raise ValueError(f"Persona with ID {persona_id} does not exist")
return persona
# or check if user owns persona
@@ -715,7 +740,7 @@ def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
stmt = delete(Persona).where(
Persona.name == persona_name, Persona.default_persona == is_default
Persona.name == persona_name, Persona.builtin_persona == is_default
)
db_session.execute(stmt)

View File

@@ -1,3 +1,5 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -13,10 +15,12 @@ from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
@@ -89,6 +93,30 @@ def get_current_db_embedding_provider(
return current_embedding_provider
def delete_search_settings(db_session: Session, search_settings_id: int) -> None:
current_settings = get_current_search_settings(db_session)
if current_settings.id == search_settings_id:
raise ValueError("Cannot delete currently active search settings")
# First, delete associated index attempts
index_attempts_query = delete(IndexAttempt).where(
IndexAttempt.search_settings_id == search_settings_id
)
db_session.execute(index_attempts_query)
# Then, delete the search settings
search_settings_query = delete(SearchSettings).where(
and_(
SearchSettings.id == search_settings_id,
SearchSettings.status != IndexModelStatus.PRESENT,
)
)
db_session.execute(search_settings_query)
db_session.commit()
def get_current_search_settings(db_session: Session) -> SearchSettings:
query = (
select(SearchSettings)
@@ -115,6 +143,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)
all_settings = result.scalars().all()
return list(all_settings)
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
@@ -146,6 +181,14 @@ def update_current_search_settings(
logger.warning("No current search settings found to update")
return
# Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
if (
search_settings.rerank_provider_type is None
and search_settings.rerank_model_name is not None
and current_settings.rerank_model_name != search_settings.rerank_model_name
):
warm_up_cross_encoder(search_settings.rerank_model_name)
update_search_settings(current_settings, search_settings, preserved_fields)
db_session.commit()
logger.info("Current search settings updated successfully")
@@ -234,6 +277,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
api_url=None,
)
@@ -246,4 +290,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
api_url=None,
)

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -14,8 +15,11 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def _build_persona_name(channel_names: list[str]) -> str:
@@ -62,7 +66,7 @@ def create_slack_bot_persona(
llm_model_version_override=None,
starter_messages=None,
is_public=True,
default_persona=False,
is_default_persona=False,
db_session=db_session,
commit=False,
)
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
return persona
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
return []
def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
if len(existing_standard_answer_categories) == 0:
raise EERequiredError(
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
)
else:
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
@@ -117,9 +140,18 @@ def update_slack_bot_config(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(

View File

@@ -44,12 +44,11 @@ def get_latest_task_by_type(
def register_task(
task_id: str,
task_name: str,
db_session: Session,
) -> TaskQueueState:
new_task = TaskQueueState(
task_id=task_id, task_name=task_name, status=TaskStatus.PENDING
task_id="", task_name=task_name, status=TaskStatus.PENDING
)
db_session.add(new_task)

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -25,6 +26,7 @@ def create_tool(
name: str,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@@ -33,6 +35,9 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,
)
db_session.add(new_tool)
@@ -45,6 +50,7 @@ def update_tool(
name: str | None,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@@ -60,6 +66,8 @@ def update_tool(
tool.openapi_schema = openapi_schema
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
db_session.commit()
return tool

View File

@@ -1,9 +1,12 @@
from collections.abc import Sequence
from uuid import UUID
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.db.models import User
@@ -20,8 +23,23 @@ def list_users(
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = db_session.query(User).filter(User.email == email).first() # type: ignore
user = (
db_session.query(User)
.filter(func.lower(User.email) == func.lower(email))
.first()
)
return user
@@ -30,3 +48,52 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
return user
def _generate_non_web_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
return User(
email=email,
hashed_password=hashed_pass,
has_web_login=False,
role=UserRole.BASIC,
)
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.commit()
return user
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.flush() # generate id
return user
def batch_add_non_web_user_if_not_exists__no_commit(
db_session: Session, emails: list[str]
) -> list[User]:
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
db_session.add_all(new_users)
db_session.flush() # generate ids
return found_users + new_users

View File

@@ -177,6 +177,30 @@ class Updatable(abc.ABC):
- Whether the document is hidden or not, hidden documents are not returned from search
"""
@abc.abstractmethod
def update_single(self, update_request: UpdateRequest) -> None:
"""
Updates some set of chunks for a document. The document and fields to update
are specified in the update request. Each update request in the list applies
its changes to a list of document ids.
None values mean that the field does not need an update.
The rationale for a single update function is that it allows retries and parallelism
to happen at a higher / more strategic level, is simpler to read, and allows
us to individually handle error conditions per document.
Parameters:
- update_request: for a list of document ids in the update request, apply the same updates
to all of the documents with those ids.
Return:
- an HTTPStatus code. The code can used to decide whether to fail immediately,
retry, etc. Although this method likely hits an HTTP API behind the
scenes, the usage of HTTPStatus is a convenience and the interface is not
actually HTTP specific.
"""
raise NotImplementedError
@abc.abstractmethod
def update(self, update_requests: list[UpdateRequest]) -> None:
"""

View File

@@ -26,6 +26,17 @@
<disk>0.75</disk>
</resource-limits>
</tuning>
<engine>
<proton>
<tuning>
<searchnode>
<requestthreads>
<persearch>SEARCH_THREAD_NUMBER</persearch>
</requestthreads>
</searchnode>
</tuning>
</proton>
</engine>
<config name="vespa.config.search.summary.juniperrc">
<max_matches>3</max_matches>
<length>750</length>
@@ -33,4 +44,4 @@
<min_length>300</min_length>
</config>
</content>
</services>
</services>

View File

@@ -30,6 +30,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from danswer.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
from danswer.document_index.vespa_constants import MAX_OR_CONDITIONS
from danswer.document_index.vespa_constants import METADATA
from danswer.document_index.vespa_constants import METADATA_SUFFIX
from danswer.document_index.vespa_constants import PRIMARY_OWNERS
@@ -292,12 +293,11 @@ def query_vespa(
if LOG_VESPA_TIMING_INFORMATION
else {},
)
response = requests.post(
SEARCH_ENDPOINT,
json=params,
)
try:
response = requests.post(
SEARCH_ENDPOINT,
json=params,
)
response.raise_for_status()
except requests.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
@@ -319,6 +319,12 @@ def query_vespa(
logger.debug("Vespa timing info: %s", response_json.get("timing"))
hits = response_json["root"].get("children", [])
if not hits:
logger.warning(
f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}"
)
logger.debug(f"Vespa Response: {response.text}")
for hit in hits:
if hit["fields"].get(CONTENT) is None:
identifier = hit["fields"].get("documentid") or hit["id"]
@@ -379,7 +385,7 @@ def batch_search_api_retrieval(
capped_requests: list[VespaChunkRequest] = []
uncapped_requests: list[VespaChunkRequest] = []
chunk_count = 0
for request in chunk_requests:
for req_ind, request in enumerate(chunk_requests, start=1):
# All requests without a chunk range are uncapped
# Uncapped requests are retrieved using the Visit API
range = request.range
@@ -387,9 +393,10 @@ def batch_search_api_retrieval(
uncapped_requests.append(request)
continue
# If adding the range to the chunk count is greater than the
# max query size, we need to perform a retrieval to avoid hitting the limit
if chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE:
if (
chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE
or req_ind % MAX_OR_CONDITIONS == 0
):
retrieved_chunks.extend(
_get_chunks_via_batch_search(
index_name=index_name,

View File

@@ -16,6 +16,7 @@ import requests
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
@@ -52,6 +53,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import NUM_THREADS
from danswer.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
@@ -118,7 +120,7 @@ class VespaIndex(DocumentIndex):
secondary_index_embedding_dim: int | None,
) -> None:
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.debug(f"Sending Vespa zip to {deploy_url}")
logger.info(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "danswer", "document_index", "vespa", "app_config"
@@ -134,6 +136,10 @@ class VespaIndex(DocumentIndex):
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_dynamic_config_store()
needs_reindexing = False
@@ -282,7 +288,7 @@ class VespaIndex(DocumentIndex):
raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(f"Updating {len(update_requests)} documents in Vespa")
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
# Handle Vespa character limitations
# Mutating update_requests but it's not used later anyway
@@ -371,6 +377,91 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, update_request: UpdateRequest) -> None:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
if len(update_request.document_ids) != 1:
raise ValueError("update_request must contain a single document id")
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
update_request.document_ids = [
replace_invalid_doc_id_characters(doc_id)
for doc_id in update_request.document_ids
]
# update_start = time.monotonic()
# Fetch all chunks for each document ahead of time
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
logger.debug(
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
}
if update_request.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
)
# logger.debug(
# "Finished updating Vespa documents in %.2f seconds",
# time.monotonic() - update_start,
# )
return
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")

View File

@@ -162,14 +162,16 @@ def _index_vespa_chunk(
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners),
SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners),
# the only `set` vespa has is `weightedset`, so we have to give each
# element an arbitrary weight
# rkuo: acl, docset and boost metadata are also updated through the metadata sync queue
# which only calls VespaIndex.update
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
BOOST: chunk.boost,
}
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"

View File

@@ -7,6 +7,7 @@ from danswer.configs.constants import SOURCE_TYPE
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
# config server
@@ -25,6 +26,9 @@ NUM_THREADS = (
32 # since Vespa doesn't allow batching of inserts / updates, we use threads
)
MAX_ID_SEARCH_QUERY_SIZE = 400
# Suspect that adding too many "or" conditions will cause Vespa to timeout and return
# an empty list of hits (with no error status and coverage: 0 and degraded)
MAX_OR_CONDITIONS = 10
# up from 500ms for now, since we've seen quite a few timeouts
# in the long term, we are looking to improve the performance of Vespa
# so that we can bring this back to default

Some files were not shown because too many files have changed in this diff Show More