Compare commits

..

129 Commits

Author SHA1 Message Date
pablodanswer
8b9e1a07d5 typing 2024-11-11 09:26:46 -08:00
pablodanswer
b6301ffcb9 spacing 2024-11-11 09:05:01 -08:00
pablodanswer
490ce0db18 cleaner approach 2024-11-11 09:03:49 -08:00
pablodanswer
b2ca13eaae treat async values differently 2024-11-11 08:59:16 -08:00
Chris Weaver
ba805f766f New assistants api (#3097) 2024-11-11 07:55:23 -08:00
rkuo-danswer
9d57f34c34 re-enable helm (#3053)
* re-enable helm

* allow manual triggering

* change vespa host

* change vespa chart location

* update Chart.lock

* update ct.yaml with new vespa chart repo

* bump vespa to 0.2.5

* update Chart.lock

* update to vespa 0.2.6

* bump vespa to 0.2.7

* bump to 0.2.8

* bump version

* try appending the ordinal

* try new configmap

* bump vespa

* bump vespa

* add debug to see if we can figure out what ct install thinks is failing

* add debug flag to helm

* try disabling nginx because of KinD

* use helm-extra-set-args

* try command line

* try pointing test connection to the correct service name

* bump vespa to 0.2.12

* update chart.lock

* bump vespa to 0.2.13

* bump vespa to 0.2.14

* bump vespa

* bump vespa

* re-enable chart testing only on changes

* name the check more specifically than "lint-test"

* add some debugging

* try setting remote

* might have to specify chart dirs directly

* add comments

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-10 01:28:39 +00:00
pablodanswer
cc2f584321 Silence auth logs (#3098)
* silence auth logs

* remove unnecessary line

* k
2024-11-09 21:41:11 +00:00
pablodanswer
a1b95df3b8 Robustify cloud deployment + include initial KEDA configuration (#3094)
* robustify cloud deployment + include initial KEDA configuration

* ensure .github changes are passed

* raise exits
2024-11-09 21:26:51 +00:00
pablodanswer
9272d6ebfe Remove ee (#3093)
* move api key to non-ee

* finalize previous migration

* move token rate limit to non-ee

* general cleanup

* update

* update

* finalize

* finalize

* ensure callable

* k
2024-11-09 20:51:36 +00:00
Yuhong Sun
4fb65dcf73 Reenable OpenAI Tokenizer (#3062)
* k

* clean up test embeddings

* nit

* minor update to ensure consistency

* minor organizational update

* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-11-08 22:54:15 +00:00
rkuo-danswer
2bbc5d5d07 fix saving docker logs (#3090) 2024-11-08 19:54:48 +00:00
rkuo-danswer
950b1c38f2 Merge pull request #3080 from danswer-ai/robust_assistant_description
Account for malformatted starter messages
2024-11-08 11:28:19 -08:00
Yuhong Sun
99fbfba32f File Connector Metadata (#3089) 2024-11-08 10:49:59 -08:00
pablodanswer
0a59efe64a account for malformatted starter messages 2024-11-08 10:21:04 -08:00
pablodanswer
cf5d394d39 adjust default postgres schema for slack listener (#3088) 2024-11-08 18:00:44 +00:00
pablodanswer
f6d8f5ca89 Migrate tenant upgrades to data plane (#3051)
* add provisioning on data plane

* functional but scrappy

* minor cleanup

* minor clean up

* k

* simplify

* update provisioning

* improve import logic

* ensure proper conditional

* minor pydantic update

* minor config update

* nit
2024-11-08 17:13:29 +00:00
hagen-danswer
1fb4cdfcc3 Merge pull request #3073 from skylares/fireflies-dev
Fireflies connector
2024-11-08 06:50:22 -08:00
hagen-danswer
ac51469bcb Merge branch 'main' into fireflies-dev 2024-11-07 18:56:37 -08:00
Skylar Kesselring
c25f164e28 Remove linux 2024-11-07 21:51:58 -05:00
Skylar Kesselring
813720905b Fix failure cases 2024-11-07 21:37:41 -05:00
rkuo-danswer
0c45488ac6 wait for db before allowing worker to proceed (reduces error spam on … (#3079)
* wait for db before allowing worker to proceed (reduces error spam on container startup)

* fix session usage

* rework readiness probe logic to be less confusing and word ongoing probes better

* add vespa probe too

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-08 01:25:09 +00:00
Skylar Kesselring
95d9b33c1a Clean up connector 2024-11-07 19:51:40 -05:00
Yuhong Sun
55919f596c PG Dev Max Connections (#3082) 2024-11-07 11:51:23 -08:00
pablodanswer
1d0fb6d012 Evaluate None to default (#3069)
* add sentinel value

* update typing

* clearer

* update comments

* ensure proper attribution
2024-11-07 18:41:42 +00:00
pablodanswer
2b1dbde829 minor improvements (#3081) 2024-11-07 18:35:49 +00:00
hagen-danswer
2758ffd9d5 Google Drive Improvements (#3057)
* Google Drive Improvements

* mypy

* should work!

* variable cleanup

* final fixes
2024-11-07 02:07:35 +00:00
pablodanswer
07a1b49b4f update persona defaults (#3042)
* evaluate None to default

* fix usage report pagination

* update persona defaults

* update user preferences

* k

* validate

* update typing

* nit

* formating nits

* fallback to all assistants

* update ux + spacing

* udpate refresh logic

* minor update to refresh

* nit

* touchup

* update starter message

* update default live assistant logic

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-11-07 00:03:14 +00:00
pablodanswer
43d8daa5bc update redirect 2024-11-06 14:55:32 -08:00
hagen-danswer
faeb9f09f0 Merge pull request #3008 from danswer-ai/horizontal_slack
Add Functional Horizontal scaling for Slack
2024-11-06 14:31:13 -08:00
pablodanswer
25f5c12750 remove print 2024-11-06 13:49:16 -08:00
pablodanswer
2d81710ccc minor udpate 2024-11-06 13:49:16 -08:00
pablodanswer
187a7d2da2 validated approach 2024-11-06 13:49:16 -08:00
pablodanswer
4b152aa3a7 update slack 2024-11-06 13:49:16 -08:00
pablodanswer
06f937cf93 no typing 2024-11-06 13:49:16 -08:00
pablodanswer
5a24ed2947 updated cleanup 2024-11-06 13:49:16 -08:00
pablodanswer
2372e6a5a5 update slack 2024-11-06 13:49:15 -08:00
pablodanswer
3eef4e3992 functioning 2024-11-06 13:47:47 -08:00
pablodanswer
467ce4e3f3 fix usage report pagination 2024-11-06 13:21:00 -08:00
Skylar Kesselring
ee4b334a0a Fix errors and cleanup 2024-11-06 14:01:51 -05:00
pablodanswer
4087292001 evaluate None to default 2024-11-06 09:36:43 -08:00
rkuo-danswer
da6ed5b2b3 Merge pull request #3066 from danswer-ai/bugfix/log-vespa-url
need to see vespa url for container debugging
2024-11-06 00:35:10 -08:00
Richard Kuo
864ac2ac5c need to see vespa url for container debugging 2024-11-06 00:26:55 -08:00
rkuo-danswer
12cb77c80e Merge pull request #3059 from danswer-ai/bugfix/sentry_indexing
add sentry to spawned indexing task
2024-11-05 16:51:23 -08:00
Richard Kuo (Danswer)
583cd14bf4 comment why we need sentry here 2024-11-05 16:46:50 -08:00
Richard Kuo (Danswer)
001fcb3359 fix stale indexing tasks being allowed to run after a restart 2024-11-05 16:39:54 -08:00
Skylar Kesselring
7ff18e0a93 Create connector 2024-11-05 19:28:57 -05:00
Richard Kuo (Danswer)
9ac256e925 Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/sentry_indexing 2024-11-05 15:48:23 -08:00
hagen-danswer
08600db41d Merge pull request #3056 from danswer-ai/form_stretch
Improve form
2024-11-05 14:19:11 -08:00
rkuo-danswer
6bf06ac7f7 limit session scope of index attempt (use id's where appropriate as w… (#3049)
* limit session scope of index attempt (use id's where appropriate as well)

* fix session scope

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-05 20:51:43 +00:00
Richard Kuo (Danswer)
5b06b53a3e add sentry to spawned indexing task 2024-11-05 12:30:21 -08:00
pablodanswer
afce57b29f clarity 2024-11-05 10:44:12 -08:00
pablodanswer
257dbecd1d k 2024-11-05 10:24:48 -08:00
pablodanswer
bd6baf39c3 update 2024-11-05 10:23:52 -08:00
pablodanswer
b2c55ebd71 ensure props aligned (#3050)
* ensure props aligned

* k

* k
2024-11-05 16:49:04 +00:00
pablodanswer
dea7a8f697 Clean up tooltips (#3047)
* clean up tooltips

* nit: fix delay duration
2024-11-05 16:48:19 +00:00
pablodanswer
ddae2346ec form 2024-11-05 08:33:03 -08:00
Weves
9032fb4467 Improve background token refresh 2024-11-04 15:00:16 -08:00
rkuo-danswer
b6ecbbcf45 add to async get session as well (#3046) 2024-11-04 20:47:56 +00:00
pablodanswer
1d8e662b79 ensure we reset all (#3048) 2024-11-04 19:48:15 +00:00
pablodanswer
2cb33b1fb4 add default api keys for cloud users (#3044)
* add default api keys for cloud users

* add cohere as well

* naming
2024-11-04 19:11:12 +00:00
hagen-danswer
2cd1e6be00 gmail refactor + permission syncing (#3021)
* initial frontend changes and shared google refactoring

* gmail connector is reworked

* added permission syncing for gmail

* tested!

* Added tests for gmail connector

* fixed tests and mypy

* temp fix

* testing done!

* rename

* test fixes maybe?

* removed irrelevant tests

* anotha one

* refactoring changes

* refactor finished

* maybe these fixes work

* dumps

* final fixes
2024-11-04 18:06:23 +00:00
Weves
8e55566f66 Fix slack bot form + LLM provider form 2024-11-03 17:51:04 -08:00
pablodanswer
bafb95d920 Misc color clean up (#3026)
* misc color clean up

* additional nits

* nit

* nit

* additional minor nits

* ensure tailwind config evaluates properly + update textarea -> input

* ensure tool call renders

* formatting
2024-11-03 23:57:11 +00:00
pablodanswer
c6e8bf2d28 add multiple formats to tools (#3041) 2024-11-03 23:54:19 +00:00
Chris Weaver
c2d04f591d Add drive sections (#3040)
* ADd header support for drive

* Fix mypy

* Comment change

* Improve

* Cleanup

* Add comment
2024-11-03 22:10:45 +00:00
rkuo-danswer
56c3a5ff5b add POSTGRES_IDLE_SESSIONS_TIMEOUT (#3019)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-03 21:58:12 +00:00
Yuhong Sun
fac2b100a1 Last Message Too Large Logging (#3039) 2024-11-03 11:24:04 -08:00
pablodanswer
51b79f688a Tool call per message (#3025)
* single tool call per message

* finalize migration

* minor image generation fix

* validate simplify

* k

* remove print

* validated
2024-11-03 10:51:51 -08:00
pablodanswer
a7002dfa1d add CSV display (#3028)
* add CSV display

* add downloading

* restructure

* create portal for modal

* update requirements

* nit
2024-11-03 10:43:05 -08:00
pablodanswer
93d0104d3c slight upgrade to image generation prompts (#3036)
* slight upgrade to prompts

* k

* nit
2024-11-03 10:42:52 -08:00
pablodanswer
46e5ffa3ae add validated + reformatted dynamic beat acquisition (#3006)
* add validated + reformatted dynamic beat acquisition

* validate

* reorg

* nit

* address comments

* update

* typing

* ensure versioned apps capture

* Remove locks (#3017)

* add validated + reformatted dynamic beat acquisition

* initial removal of locks!

* minor

* remove unecessary locks

* update

* nit

* k

* K8s jobs (#3033)

* add k8s configs

* k

* update config

* k

* improved timeouts + worker configs

* improve workers
2024-11-03 10:27:25 -08:00
pablodanswer
d4f38bba8b Revert temporary modifications (#3038)
* Revert temporary modifications

* nit
2024-11-03 10:27:06 -08:00
pablodanswer
19d6b63fd3 temporary update (#3037) 2024-11-03 10:05:33 -08:00
Chris Weaver
938d5788b6 Upgrade to latest NextJS + switch to turbopack (#3027)
* Upgrade to NextJS 15 + use turbopacK

* Remove unintended change

* Update nextjs version

* Remove override

* Upgrade react

* Fix charts

* Style

* Style

* Fix prettier

* slight modification

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-11-03 02:56:23 +00:00
hagen-danswer
70f703cc0f Merge pull request #3035 from danswer-ai/freshdesk-nit
minor nit
2024-11-02 18:14:52 -07:00
hagen-danswer
8bcf80aa76 minor nit 2024-11-02 18:05:06 -07:00
rkuo-danswer
5f5cc9a724 Feature/redis connector refactor (#2992)
* refactor RedisConnectorDeletion into RedisConnector

* refactor redis stop and deletion

* port pruning

* nest pruning

* port deletion

* port indexing

* refactor into individual files

* refactor redis connector index  to take search settings at init

* move back to debug level log

* refactor doc set and user group (mostly)

* mypy fixes
2024-11-02 19:53:04 +00:00
pablodanswer
e4bb14d4e1 Super user (#2944)
* add super user

* nits
2024-11-02 17:29:23 +00:00
hagen-danswer
5d9b8364ab Merge pull request #3032 from danswer-ai/freshdesk-cleanup
Cleaned up connector
2024-11-02 09:31:22 -07:00
hagen-danswer
83c299ebc8 troll logger statement 2024-11-02 09:09:46 -07:00
hagen-danswer
6b4143cc30 ID fix 2024-11-02 09:08:26 -07:00
hagen-danswer
6e8c88ed71 made id more unique 2024-11-02 09:05:24 -07:00
hagen-danswer
d652cb3141 renamed variables 2024-11-02 09:03:42 -07:00
hagen-danswer
5e444d43f9 Cleaned up connector 2024-11-02 09:01:15 -07:00
hagen-danswer
2e49027beb Merge pull request #2884 from skylares/sky-dev
Add Freshdesk Connector
2024-11-02 08:27:35 -07:00
hagen-danswer
d7bcd32d9a out of scope 2024-11-02 08:21:33 -07:00
hagen-danswer
4a6b8db65f out of scope 2024-11-02 08:20:08 -07:00
hagen-danswer
6f440d126a more mypy fixes 2024-11-02 08:17:53 -07:00
hagen-danswer
013292a0e3 mypy fixes 2024-11-02 08:15:36 -07:00
Richard Kuo
a1ae22ef4a fix run key 2024-11-02 02:23:08 -07:00
Richard Kuo
40beda30a4 try pip-license-checker 2024-11-02 02:20:58 -07:00
Richard Kuo
d3062cacea manual only for now 2024-11-02 00:01:55 -07:00
Richard Kuo
678ed23853 codel permissions? 2024-11-01 22:34:41 -07:00
Richard Kuo
ea2da63cf2 try installing npm deps 2024-11-01 22:09:06 -07:00
Richard Kuo
4fc8a35220 try repo level scan 2024-11-01 21:59:23 -07:00
hagen-danswer
f981106111 Update connector.py 2024-11-01 19:27:03 -07:00
Richard Kuo (Danswer)
5439c33313 don't scan the os packages 2024-11-01 17:24:41 -07:00
Richard Kuo (Danswer)
5e050f8305 we didn't checkout the code, no trivy ignore 2024-11-01 17:16:28 -07:00
Richard Kuo (Danswer)
12c82de78f experimental github action to scan licenses 2024-11-01 17:10:59 -07:00
pablodanswer
645402c71a Tremor -> Shadcn (#2983)
* initialization

* button + input updates

* migrate dividers + buttons

* migrate badges

* minor updates

* migrate cards

* fix compiling

* begin date picker + badge transfer

* remove tremor

* fully swapped

* nits

* list item + configuration updates

* clean build

* update colors

* nits
2024-11-01 23:20:06 +00:00
pablodanswer
772313236f minor foreign key update (#3007) 2024-11-01 21:16:50 +00:00
Chris Weaver
ecf4923a3a Fix answer with specified doc ids (#2703)
* Fix

Fix

Refactor

more

more

fix

refactor

Fix circular imports

Refactor

Move tests around

* Add quote support

* Testing

* More testing

* Fix image generation slowness

* Remove unused exception

* Fix UT

* fix stop generating

* minor typo

* minor logging updates for clarity

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-11-01 19:50:20 +00:00
pablodanswer
d66b81a902 Feat/certificate (#2998)
* first pass

* simplify

* remove now unneeded COPY command

* minor clean up

* k

* nit
2024-11-01 19:34:52 +00:00
pablodanswer
753293cefb Basic multi tenant api key (#3004)
* basic multi tenant api key

* organization

* nit

* clean
2024-11-01 19:34:51 +00:00
pablodanswer
6d543f3d4f Do not count API keys as users (#3022)
* don't count api keys as users

* typing
2024-11-01 19:34:30 +00:00
hagen-danswer
ccdc09e2d4 Merge pull request #3020 from danswer-ai/gdrive-interface
Add Gdrive Interface
2024-11-01 06:28:56 -07:00
hagen-danswer
4a23c8702d Quicky 2024-11-01 06:27:55 -07:00
rkuo-danswer
dc2dfeb5b8 Fix pywikibot droppings (#2924)
* make pywikibot store its working files in a system provided temp directory

* move the config setting around

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-01 05:59:12 +00:00
hagen-danswer
71d4fb98d3 Refactored Google Drive Connector + Permission Syncing (#2945)
* refactoring changes

* everything working for service account

* works with service account

* combined scopes

* copy change

* oauth prep

* Works for oauth and service account credentials

* mypy

* merge fixes

* Refactor Google Drive connector

* finished backend

* auth changes

* if its stupid but it works, its not stupid

* npm run dev fixes

* addressed change requests

* string fix

* minor fixes and cleanup

* spacing cleanup

* Update connector.py

* everything done

* testing!

* Delete backend/tests/daily/connectors/google_drive/file_generator.py

* cleaned up

---------

Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-11-01 02:25:00 +00:00
Yuhong Sun
b34f5862d7 Remove License Issues (#3013)
* k

* k

* k

* k

* k
2024-11-01 00:31:19 +00:00
pablodanswer
0b08bf4e3f Proper tenant reset (#3015)
* add proper tenant reset

* clear comment

* minor formatting
2024-10-31 19:45:35 +00:00
pablodanswer
add87fa1b4 remove endpoint (#3014) 2024-10-31 19:43:15 +00:00
Samarth Mishra
787fdf2e38 Update README.md (#3011) 2024-10-31 10:44:36 -07:00
hagen-danswer
e3be318781 Update connector.py 2024-10-31 09:50:48 -07:00
Skylar Kesselring
73ee709801 Fix typing errors 2024-10-30 17:46:04 -04:00
Skylar Kesselring
53d2d333ab Refactor metadata 2024-10-30 17:23:20 -04:00
Skylar Kesselring
195e2c335d Fix per_page count 2024-10-28 12:35:40 -04:00
Skylar Kesselring
1dec69bb82 Fix document time parsing 2024-10-28 12:33:58 -04:00
Skylar Kesselring
075e4f18bc Clean up & comment fetch_tickets 2024-10-28 11:26:37 -04:00
Skylar Kesselring
e5494f9742 Refactor & cleanup code, process tickets in batches 2024-10-27 11:53:50 -04:00
Skylar Kesselring
e5d84cae1b Clean up code 2024-10-26 23:06:24 -04:00
Skylar Kesselring
8023cafb2b Fixed polling issue with timezone 2024-10-25 23:46:47 -04:00
Skylar Kesselring
a348caa9b1 Add pagination & Remove req.obj from connectors.tsx 2024-10-25 14:12:11 -04:00
Skylar Kesselring
245adc4d3d Remove 2 month time check & Add time range to fetch and process 2024-10-24 12:42:08 -04:00
Skylar Kesselring
4ad35d76b0 Make ticket fetching a seperate function from processing 2024-10-24 12:25:29 -04:00
Skylar Kesselring
cc1e1c178b Replace html processing library with danswer util 2024-10-24 11:49:11 -04:00
Skylar Kesselring
87b5975091 Remove unnecessary log & Add LoadConnector 2024-10-24 11:38:29 -04:00
Skylar Kesselring
85b56e39c9 Fix Freshdesk connector date parsing for UTC timestamps 2024-10-23 14:01:03 -04:00
Skylar Kesselring
a1680fac2f Implement freshdesk frontend 2024-10-23 12:58:15 -04:00
445 changed files with 18649 additions and 7392 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-backend
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,17 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -95,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
severity: "CRITICAL,HIGH"

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-model-server
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -0,0 +1,76 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -210,17 +210,18 @@ jobs:
echo "All integration tests passed successfully."
fi
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,24 +1,20 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
lint-test:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0
@@ -28,7 +24,7 @@ jobs:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
@@ -45,24 +41,31 @@ jobs:
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -18,6 +18,11 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
jobs:
connectors-check:

View File

@@ -1,4 +1,5 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -127,3 +128,19 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,7 +12,6 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -39,15 +38,6 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -87,7 +77,6 @@ RUN apt-get update && \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,7 +1,7 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 949b4a92a401
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
@@ -11,67 +11,40 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "949b4a92a401"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1. Add 'message_id' column to 'tool_call' table
op.add_column("tool_call", sa.Column("message_id", sa.Integer(), nullable=True))
# 2. Create foreign key constraint from 'tool_call.message_id' to 'chat_message.id'
op.create_foreign_key(
"fk_tool_call_message_id",
"tool_call",
"chat_message",
["message_id"],
["id"],
)
# 3. Migrate existing data from 'chat_message.tool_call_id' to 'tool_call.message_id'
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
UPDATE tool_call
SET message_id = chat_message.id
FROM chat_message
WHERE chat_message.tool_call_id = tool_call.id
"""
)
)
# 4. Drop the foreign key constraint and column 'tool_call_id' from 'chat_message' table
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")
# 5. Optionally drop the unique constraint if it was previously added
# op.drop_constraint("uq_chat_message_tool_call_id", "chat_message", type_="unique")
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# 1. Add 'tool_call_id' column back to 'chat_message' table
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)
# 2. Restore foreign key constraint from 'chat_message.tool_call_id' to 'tool_call.id'
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# 3. Migrate data back from 'tool_call.message_id' to 'chat_message.tool_call_id'
op.execute(
"""
UPDATE chat_message
SET tool_call_id = tool_call.id
FROM tool_call
WHERE tool_call.message_id = chat_message.id
"""
)
# 4. Drop the foreign key constraint and column 'message_id' from 'tool_call' table
op.drop_constraint("fk_tool_call_message_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")

View File

@@ -0,0 +1,70 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -288,6 +288,15 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -0,0 +1,48 @@
"""remove description from starter messages
Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem - 'description')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem || '{"description": ""}')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -0,0 +1,29 @@
"""add recent assistants
Revision ID: c0fd6e4da83a
Revises: b72ed7a5db0e
Create Date: 2024-11-03 17:28:54.916618
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c0fd6e4da83a"
down_revision = "b72ed7a5db0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)
def downgrade() -> None:
op.drop_column("user", "recent_assistants")

View File

@@ -23,6 +23,56 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -1,12 +1,14 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote
from fastapi import Request
from passlib.hash import sha256_crypt
from pydantic import BaseModel
from danswer.auth.schemas import UserRole
from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
_API_KEY_HEADER_NAME = "Authorization"
@@ -30,8 +32,35 @@ class ApiKeyDescriptor(BaseModel):
user_id: uuid.UUID
def generate_api_key() -> str:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
if len(parts) != 2:
return None
tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None
def hash_api_key(api_key: str) -> str:

View File

@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -75,6 +75,7 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
@@ -83,24 +84,27 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -190,20 +194,6 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -238,19 +228,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -271,7 +255,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@@ -292,7 +276,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@@ -308,19 +294,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
)
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -371,9 +356,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -453,7 +438,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -477,8 +468,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -510,19 +500,30 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> JWTStrategy:
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
@@ -624,14 +625,12 @@ async def double_check_user(
return None
if user is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
@@ -640,8 +639,7 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
@@ -667,15 +665,13 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not a curator or admin.",
)
@@ -687,8 +683,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User must be an admin to perform this action.",
)
@@ -881,3 +876,22 @@ def get_oauth_router(
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,6 +3,7 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -11,18 +12,22 @@ from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
@@ -108,29 +113,27 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(int(document_set_id))
rds = RedisDocumentSet(tenant_id, int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(int(usergroup_id))
rug = RedisUserGroup(tenant_id, int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
@@ -140,77 +143,154 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
logger.info("Redis: Readiness probe starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@@ -218,7 +298,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("All tenant primary workers are ready. Continuing...")
logger.info("Wait for primary worker completed successfully. Continuing...")
return
@@ -230,26 +310,20 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_locks"):
if not sender.primary_worker_lock:
return
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
def on_setup_logging(

View File

@@ -3,28 +3,152 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
logger = setup_logger(__name__)
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# celery beat shouldn't touch the db at all. But just setting a low minimum here.
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
@@ -35,68 +159,4 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
celery_app.conf.beat_scheduler = DynamicTenantScheduler

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,7 +61,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,7 +61,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -59,8 +60,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -85,5 +91,6 @@ celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
]
)

View File

@@ -13,21 +13,21 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -75,95 +75,64 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
RedisDocumentSet.reset_all(r)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
RedisUserGroup.reset_all(r)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorDelete.reset_all(r)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorPrune.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorIndex.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorStop.reset_all(r)
@worker_ready.connect
@@ -216,52 +185,36 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_locks"):
if not hasattr(worker, "primary_worker_lock"):
return
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock = worker.primary_worker_lock
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
r = get_redis_client(tenant_id=None)
r = get_redis_client(tenant_id=tenant_id)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
except Exception:
task_logger.exception("Periodic task failed.")

View File

@@ -0,0 +1,96 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -1,568 +1,10 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:

View File

@@ -4,7 +4,6 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -18,7 +17,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -41,14 +40,14 @@ def _get_deletion_status(
if not cc_pair:
return None
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
return None
return TaskQueueState(
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)

View File

@@ -0,0 +1,48 @@
from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -10,13 +10,6 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@@ -25,6 +18,8 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -62,7 +57,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
rcs = RedisConnectorStop(cc_pair_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
@@ -71,10 +66,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
r.set(rcs.fence_key, cc_pair_id)
redis_connector.stop.set_fence(True)
else:
# clear the stop signal if it exists ... no longer needed
r.delete(rcs.fence_key)
redis_connector.stop.set_fence(False)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -106,10 +101,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
if redis_connector.delete.fenced:
return None
# we need to load the state of the object inside the fence
@@ -123,47 +118,49 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_value = RedisConnectorDeletionFenceData(
fence_payload = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
r.set(rcd.fence_key, fence_value.model_dump_json())
redis_connector.delete.set_fence(fence_payload)
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
if redis_connector.prune.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
redis_connector.delete.taskset_clear()
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
tasks_generated = redis_connector.delete.generate_tasks(
app, db_session, lock_beat
)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
r.delete(rcd.fence_key)
redis_connector.delete.set_fence(None)
raise
except Exception:
task_logger.exception("Unexpected exception")
r.delete(rcd.fence_key)
redis_connector.delete.set_fence(None)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
@@ -178,7 +175,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
# set this only after all tasks have been added
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
return tasks_generated

View File

@@ -2,10 +2,9 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
@@ -14,12 +13,6 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
@@ -50,12 +43,15 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -129,6 +125,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
@@ -141,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if r.exists(rci.fence_key):
if redis_connector_index.fenced:
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -178,7 +175,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
if attempt_id:
task_logger.info(
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
f"Indexing queued: index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
except SoftTimeLimitExceeded:
@@ -307,15 +306,15 @@ def try_creating_indexing_task(
return None
try:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if r.exists(rci.fence_key):
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
if redis_connector.delete.fenced:
return None
db_session.refresh(cc_pair)
@@ -323,19 +322,17 @@ def try_creating_indexing_task(
return None
# add a long running generator task to the queue
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
redis_connector_index.generator_clear()
# set a basic fence to start
fence_value = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
r.set(rci.fence_key, fence_value.model_dump_json())
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
@@ -348,6 +345,8 @@ def try_creating_indexing_task(
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -364,11 +363,12 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
r.delete(rci.fence_key)
redis_connector_index.set_fence(payload)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "
@@ -486,6 +486,18 @@ def connector_indexing_task(
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
@@ -493,62 +505,60 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
attempt = None
n_final_progress = 0
attempt_found = False
n_final_progress: int | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
r = get_redis_client(tenant_id=tenant_id)
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={rcd.fence_key}"
f"fence={redis_connector.delete.fence_key}"
)
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={rcs.fence_key}"
f"fence={redis_connector.stop.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
# wait for the fence to come up
if not redis_connector_index.fenced:
raise ValueError(
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
)
sleep(1)
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
)
logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
)
break
lock = r.lock(
rci.generator_lock_key,
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -558,11 +568,10 @@ def connector_indexing_task(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -571,6 +580,7 @@ def connector_indexing_task(
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -590,37 +600,32 @@ def connector_indexing_task(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# define a callback class
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
@@ -628,14 +633,10 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt:
if attempt_found:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():

View File

@@ -11,9 +11,6 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
@@ -33,6 +30,7 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import pruning_ctx
from danswer.utils.logger import setup_logger
@@ -147,8 +145,11 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
count = redis_connector.prune.get_active_task_count()
if count > 0:
return None
LOCK_TIMEOUT = 30
@@ -165,15 +166,10 @@ def try_creating_prune_generator_task(
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
if redis_connector.prune.fenced: # skip pruning if already pruning
return None
# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
return None
db_session.refresh(cc_pair)
@@ -181,10 +177,10 @@ def try_creating_prune_generator_task(
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
@@ -200,7 +196,7 @@ def try_creating_prune_generator_task(
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
redis_connector.prune.set_fence(True)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -235,12 +231,12 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
rcp = RedisConnectorPruning(cc_pair_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
@@ -273,10 +269,11 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
rcs.fence_key, rcp.generator_progress_key, lock, r
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -303,31 +300,29 @@ def connector_pruning_generator_task(
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
redis_connector.prune.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
raise e
finally:
if lock.owned():

View File

@@ -1,8 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -1,10 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -19,18 +19,6 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -67,7 +55,14 @@ from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@@ -192,7 +187,7 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
@@ -228,10 +223,10 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set_id)
rds = RedisDocumentSet(tenant_id, document_set_id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
if rds.fenced:
return None
# don't generate sync tasks if we're up to date
@@ -269,7 +264,7 @@ def try_generate_document_set_sync_tasks(
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
rds.set_fence(tasks_generated)
return tasks_generated
@@ -283,10 +278,9 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
rug = RedisUserGroup(tenant_id, usergroup_id)
if rug.fenced:
# don't generate sync tasks if tasks are still pending
return None
# race condition with the monitor/cleanup function if we use a cached result!
@@ -326,7 +320,7 @@ def try_generate_user_group_sync_tasks(
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
rug.set_fence(tasks_generated)
return tasks_generated
@@ -352,7 +346,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -362,16 +356,12 @@ def monitor_document_set_taskset(
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
rds = RedisDocumentSet(tenant_id, document_set_id)
if not rds.fenced:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
initial_count = rds.payload
if initial_count is None:
return
count = cast(int, r.scard(rds.taskset_key))
@@ -399,48 +389,38 @@ def monitor_document_set_taskset(
f"Successfully synced document set: document_set={document_set_id}"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
rds.reset()
def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
rcd = RedisConnectorDeletion(cc_pair_id)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
count = cast(int, r.scard(rcd.taskset_key))
remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if count > 0:
if remaining > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
@@ -524,15 +504,15 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
@@ -541,46 +521,37 @@ def monitor_ccpair_pruning_taskset(
cc_pair_id = int(cc_pair_id_str)
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
return
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
initial = redis_connector.prune.generator_complete
if initial is None:
return
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if count > 0:
if remaining > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_indexing_taskset(
key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
@@ -595,53 +566,37 @@ def monitor_ccpair_indexing_taskset(
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
payload = redis_connector_index.payload
if not payload:
return
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
if payload.index_attempt_id is None or payload.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
status_int = redis_connector_index.get_completion()
if status_int is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
@@ -652,30 +607,18 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
redis_connector_index.reset()
return
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
@@ -684,11 +627,7 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
redis_connector_index.reset()
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@@ -700,7 +639,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
r = get_redis_client(tenant_id=tenant_id)
@@ -751,27 +690,33 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(key_bytes, r, db_session)
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
@@ -782,19 +727,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(key_bytes, r, db_session)
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client

View File

@@ -1,8 +1,6 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)
app = celery_app

View File

@@ -118,7 +118,13 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
@@ -331,7 +337,7 @@ def _run_indexing(
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt,
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
@@ -366,7 +372,7 @@ def _run_indexing(
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
index_attempt.id,
db_session,
failure_reason="All batches exceptioned.",
)

View File

@@ -0,0 +1,4 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
class ImageGenerationDisplay(BaseModel):
class FileChatDisplay(BaseModel):
file_ids: list[str]
@@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| FileChatDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo

View File

@@ -42,18 +42,14 @@ personas:
display_priority: 1
is_visible: true
starter_messages:
- name: "General Information"
description: "Ask about available information"
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
- name: "Specific Topic Search"
description: "Search for specific information"
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
- name: "Recent Updates"
description: "Inquire about latest additions"
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
- name: "Cross-referencing Information"
description: "Connect information from different sources"
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
- name: "Give me an overview of what's here"
message: "Sample some documents and tell me what you find."
- name: "Use AI to solve a work related problem"
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
- name: "Find updates on a topic of interest"
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
- name: "Surface contradictions"
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
- id: 1
name: "General"
@@ -71,18 +67,14 @@ personas:
display_priority: 0
is_visible: true
starter_messages:
- name: "Open Discussion"
description: "Start an open-ended conversation"
message: "Hi! Can you help me write a professional email?"
- name: "Problem Solving"
description: "Get help with a challenge"
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
- name: "Learn Something New"
description: "Explore a new topic"
message: "Hi! Could you explain what project management is in simple terms?"
- name: "Creative Brainstorming"
description: "Generate creative ideas"
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
- name: "Summarize a document"
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
- name: "Help me with coding"
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
- name: "Draft a professional email"
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
- name: "Learn something new"
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
- id: 2
name: "Paraphrase"
@@ -101,16 +93,12 @@ personas:
is_visible: false
starter_messages:
- name: "Document Search"
description: "Find exact information"
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
- name: "Process Verification"
description: "Find exact quotes"
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
- name: "Technical Documentation"
description: "Search technical details"
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
- name: "Policy Reference"
description: "Check official policies"
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
- id: 3
@@ -130,15 +118,11 @@ personas:
display_priority: 3
is_visible: true
starter_messages:
- name: "Landscape"
description: "Generate a landscape image"
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
- name: "Character"
description: "Generate a character image"
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
- name: "Abstract"
description: "Create an abstract image"
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
- name: "Urban Scene"
description: "Generate an urban landscape"
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."
- name: "Create visuals for a presentation"
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
- name: "Find inspiration for a marketing campaign"
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
- name: "Visualize a product design"
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
- name: "Generate a humorous image response"
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"

View File

@@ -11,24 +11,18 @@ from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -275,7 +261,7 @@ ChatPacket = (
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| FileChatDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
@@ -295,7 +281,6 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -307,6 +292,9 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -428,12 +416,20 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -504,13 +500,19 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -525,7 +527,13 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -537,6 +545,7 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -560,142 +569,39 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -729,9 +635,6 @@ def stream_chat_message_objects(
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, StreamStopInfo):
break
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
@@ -772,7 +675,6 @@ def stream_chat_message_objects(
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
@@ -790,7 +692,7 @@ def stream_chat_message_objects(
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
@@ -804,11 +706,32 @@ def stream_chat_message_objects(
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
@@ -854,7 +777,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -862,9 +784,11 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
@@ -898,7 +822,6 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -908,7 +831,6 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -9,19 +9,19 @@ prompts:
system: >
You are a question answering system that is constantly learning and improving.
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
You always clearly communicate ANY UNCERTAINTY in your answer.
# Task Prompt (as shown in UI)
task: >
Answer my query based on the documents provided.
The documents may not all be relevant, ignore any documents that are not directly relevant
to the most recent user query.
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
@@ -30,21 +30,21 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images based on user prompts!"
description: "Generates images from user descriptions!"
system: >
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
For appropriate requests, you will generate an image that matches the user's requirements.
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
You aim to be helpful while maintaining appropriate content standards.
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
Based on the user's description, create a high-quality image that accurately reflects their request.
Pay close attention to the specified details, styles, and desired elements.
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
@@ -64,14 +64,13 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Summarize"
description: "Summarize relevant information from retrieved context!"
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -84,7 +83,6 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Paraphrase"
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
@@ -92,10 +90,10 @@ prompts:
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
simply tell the user that there are no documents to reference.
You NEVER generate new text or phrases outside of the citation.
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
task: >

View File

@@ -163,6 +163,17 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
# Experimental setting to control idle transactions
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
try:
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
os.environ.get(
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
)
)
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
@@ -251,9 +262,6 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# for some connectors
ENABLE_EXPENSIVE_EXPERT_CALLS = False
GOOGLE_DRIVE_INCLUDE_SHARED = False
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
@@ -481,3 +489,17 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)

View File

@@ -125,6 +125,8 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -224,6 +226,9 @@ class DanswerRedisLocks:
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0

View File

@@ -17,6 +17,7 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
@@ -249,7 +250,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

View File

@@ -16,6 +16,8 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
from danswer.connectors.gmail.connector import GmailConnector
@@ -99,6 +101,8 @@ def identify_connector_class(
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -123,9 +123,13 @@ def _process_file(
"filename",
"file_display_name",
"title",
"connector_type",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
@@ -145,7 +149,7 @@ def _process_file(
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
source=source_type or DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,

View File

@@ -0,0 +1,182 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
_FIREFLIES_API_QUERY = """
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
participants
date
transcript_url
sentences {
text
speaker_name
}
}
}
"""
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
meeting_link = transcript["transcript_url"]
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
meeting_title = transcript["title"] or "No Title"
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
secondary_owners=meeting_participants_email_list,
)
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str]) -> None:
api_key = credentials.get("fireflies_api_key")
if not isinstance(api_key, str):
raise ConnectorMissingCredentialError(
"The Fireflies API key must be a string"
)
self.api_key = api_key
return None
def _fetch_transcripts(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Iterator[List[dict]]:
if self.api_key is None:
raise ConnectorMissingCredentialError("Missing API key")
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
skip = 0
variables: dict[str, int | str] = {
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
}
if start_datetime:
variables["fromDate"] = start_datetime
if end_datetime:
variables["toDate"] = end_datetime
while True:
variables["skip"] = skip
response = requests.post(
_FIREFLIES_API_URL,
headers=headers,
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
)
response.raise_for_status()
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
"transcripts", []
)
yield parsed_transcripts
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
break
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
def _process_transcripts(
self, start: str | None = None, end: str | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for transcript_batch in self._fetch_transcripts(start, end):
for transcript in transcript_batch:
if doc := _create_doc_from_transcript(transcript):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
yield from self._process_transcripts(start_datetime, end_datetime)

View File

@@ -0,0 +1,239 @@
import json
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
_TICKET_FIELDS_TO_INCLUDE = {
"fr_escalated",
"spam",
"priority",
"source",
"status",
"type",
"is_escalated",
"tags",
"nr_due_by",
"nr_escalated",
"cc_emails",
"fwd_emails",
"reply_cc_emails",
"ticket_cc_emails",
"support_email",
"to_emails",
}
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
1: "Email",
2: "Portal",
3: "Phone",
7: "Chat",
9: "Feedback Widget",
10: "Outbound Email",
}
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
1: "low",
2: "medium",
3: "high",
4: "urgent",
}
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
2: "open",
3: "pending",
4: "resolved",
5: "closed",
}
def _create_metadata_from_ticket(ticket: dict) -> dict:
metadata: dict[str, str | list[str]] = {}
# Combine all emails into a list so there are no repeated emails
email_data: set[str] = set()
for key, value in ticket.items():
# Skip fields that aren't useful for embedding
if key not in _TICKET_FIELDS_TO_INCLUDE:
continue
# Skip empty fields
if not value or value == "[]":
continue
# Convert strings or lists to strings
stringified_value: str | list[str]
if isinstance(value, list):
stringified_value = [str(item) for item in value]
else:
stringified_value = str(value)
if "email" in key:
if isinstance(stringified_value, list):
email_data.update(stringified_value)
else:
email_data.add(stringified_value)
else:
metadata[key] = stringified_value
if email_data:
metadata["emails"] = list(email_data)
# Convert source numbers to human-parsable string
if source_number := ticket.get("source"):
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
source_number, "Unknown Source Type"
)
# Convert priority numbers to human-parsable string
if priority_number := ticket.get("priority"):
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
priority_number, "Unknown Priority"
)
# Convert status to human-parsable string
if status_number := ticket.get("status"):
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
status_number, "Unknown Status"
)
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
return metadata
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
# Use the ticket description as the text
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
metadata = _create_metadata_from_ticket(ticket)
# This is also used in the ID because it is more unique than the just the ticket ID
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
Section(
link=link,
text=text,
)
],
source=DocumentSource.FRESHDESK,
semantic_identifier=ticket["subject"],
metadata=metadata,
doc_updated_at=datetime.fromisoformat(
ticket["updated_at"].replace("Z", "+00:00")
),
)
class FreshdeskConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str | int]) -> None:
api_key = credentials.get("freshdesk_api_key")
domain = credentials.get("freshdesk_domain")
password = credentials.get("freshdesk_password")
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
raise ConnectorMissingCredentialError(
"All Freshdesk credentials must be strings"
)
self.api_key = str(api_key)
self.domain = str(domain)
self.password = str(password)
def _fetch_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> Iterator[List[dict]]:
"""
'end' is not currently used, so we may double fetch tickets created after the indexing
starts but before the actual call is made.
To use 'end' would require us to use the search endpoint but it has limitations,
namely having to fetch all IDs and then individually fetch each ticket because there is no
'include' field available for this endpoint:
https://developers.freshdesk.com/api/#filter_tickets
"""
if self.api_key is None or self.domain is None or self.password is None:
raise ConnectorMissingCredentialError("freshdesk")
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
params: dict[str, int | str] = {
"include": "description",
"per_page": 50,
"page": 1,
}
if start:
params["updated_since"] = start.isoformat()
while True:
response = requests.get(
base_url, auth=(self.api_key, self.password), params=params
)
response.raise_for_status()
if response.status_code == 204:
break
tickets = json.loads(response.content)
logger.info(
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
)
yield tickets
if len(tickets) < int(params["per_page"]):
break
params["page"] = int(params["page"]) + 1
def _process_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for ticket_batch in self._fetch_tickets(start, end):
for ticket in ticket_batch:
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_tickets()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_tickets(start_datetime, end_datetime)

View File

@@ -1,283 +1,360 @@
import re
import time
from base64 import urlsafe_b64decode
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_authorized_user,
)
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_service_account,
)
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# This is for the initial list call to get the thread ids
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
attempt = 0
# These are the fields to retrieve using the ID from the initial list call
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
try:
return request.execute()
except HttpError as error:
attempt += 1
EMAIL_FIELDS = [
"cc",
"bcc",
"from",
"to",
]
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Extract 'Retry after' timestamp from error message
match = re.search(
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
str(error),
add_retries = retry_builder(tries=50, max_delay=30)
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _clean_email_and_extract_name(email: str) -> tuple[str, str | None]:
email = email.strip()
if "<" in email and ">" in email:
# Handle format: "Display Name <email@domain.com>"
display_name = email[: email.find("<")].strip()
email_address = email[email.find("<") + 1 : email.find(">")].strip()
return email_address, display_name if display_name else None
else:
# Handle plain email address
return email.strip(), None
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
owners = []
for email, names in emails.items():
if names:
name_parts = names.split(" ")
first_name = " ".join(name_parts[:-1])
last_name = name_parts[-1]
else:
first_name = None
last_name = None
owners.append(
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
)
return owners
def _get_message_body(payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
message_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain" and body:
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
message_body += text
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
headers = payload.get("headers", [])
metadata: dict[str, Any] = {}
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in EMAIL_FIELDS:
metadata[name] = value
if name == "subject":
metadata["subject"] = value
if name == "date":
metadata["updated_at"] = value
if labels := message.get("labelIds"):
metadata["labels"] = labels
message_data = ""
for name, value in metadata.items():
# updated at isnt super useful for the llm
if name != "updated_at":
message_data += f"{name}: {value}\n"
message_body_text: str = _get_message_body(payload)
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
all_messages = full_thread.get("messages", [])
if not all_messages:
return None
sections = []
semantic_identifier = ""
updated_at = None
from_emails: dict[str, str | None] = {}
other_emails: dict[str, str | None] = {}
for message in all_messages:
section, message_metadata = message_to_section(message)
sections.append(section)
for name, value in message_metadata.items():
if name in EMAIL_FIELDS:
email, display_name = _clean_email_and_extract_name(value)
if name == "from":
from_emails[email] = (
display_name if not from_emails.get(email) else None
)
else:
other_emails[email] = (
display_name if not other_emails.get(email) else None
)
if match:
retry_after_timestamp = match.group(1)
retry_after_dt = datetime.strptime(
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
).replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
sleep_time = max(
int((retry_after_dt - current_time).total_seconds()),
0,
)
else:
logger.error(
f"No Retry-After header or timestamp found in error message: {error}"
)
sleep_time = 60
sleep_time += 3 # Add a buffer to be safe
# If we haven't set the semantic identifier yet, set it to the subject of the first message
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
logger.info(
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
)
time.sleep(sleep_time)
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
else:
raise
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
# If we've exhausted all attempts
raise Exception(f"Failed to execute request after {max_attempts} attempts")
id = full_thread.get("id")
if not id:
raise ValueError("Thread ID is required")
primary_owners = _get_owners_from_emails(from_emails)
secondary_owners = _get_owners_from_emails(other_emails)
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,
secondary_owners=secondary_owners,
doc_updated_at=updated_at_datetime,
# Not adding emails to metadata because it's already in the sections
metadata={},
)
class GmailConnector(LoadConnector, PollConnector):
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._primary_admin_email: str | None = None
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorugh
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_gmail_creds_for_authorized_user(
token_json_str=access_token_json_str
)
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_gmail_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Gmail - unknown credential structure."
)
self.creds = creds
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GMAIL,
)
return new_creds_dict
def _get_email_body(self, payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
email_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain":
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
email_body += text
return email_body
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _email_to_document(self, full_email: Dict[str, Any]) -> Document:
email_id = full_email["id"]
payload = full_email["payload"]
headers = payload.get("headers")
labels = full_email.get("labelIds", [])
metadata = {}
if headers:
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in ["from", "to", "subject", "date", "cc", "bcc"]:
metadata[name] = value
email_data = ""
for name, value in metadata.items():
email_data += f"{name}: {value}\n"
metadata["labels"] = labels
logger.debug(f"{email_data}")
email_body_text: str = self._get_email_body(payload)
date_str = metadata.get("date")
email_updated_at = time_str_to_utc(date_str) if date_str else None
link = f"https://mail.google.com/mail/u/0/#inbox/{email_id}"
return Document(
id=email_id,
sections=[Section(link=link, text=email_data + email_body_text)],
source=DocumentSource.GMAIL,
title=metadata.get("subject"),
semantic_identifier=metadata.get("subject", "Untitled Email"),
doc_updated_at=email_updated_at,
metadata=metadata,
)
@staticmethod
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _fetch_mails_from_gmail(
def _fetch_threads(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.creds is None:
raise PermissionError("Not logged into Gmail")
page_token = ""
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
service = discovery.build("gmail", "v1", credentials=self.creds)
while page_token is not None:
result = _execute_with_retry(
service.users()
.messages()
.list(
userId="me",
pageToken=page_token,
q=query,
maxResults=self.batch_size,
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
full_threads = execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
)
)
page_token = result.get("nextPageToken")
messages = result.get("messages", [])
doc_batch = []
for message in messages:
message_id = message["id"]
msg = _execute_with_retry(
service.users()
.messages()
.get(userId="me", id=message_id, format="full")
)
doc = self._email_to_document(msg)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) > 0:
yield doc_batch
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def _fetch_slim_threads(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
perm_sync_data={"user_email": user_email},
)
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._fetch_mails_from_gmail()
try:
yield from self._fetch_threads()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
yield from self._fetch_mails_from_gmail(start, end)
try:
yield from self._fetch_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._fetch_slim_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GMAIL_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GmailConnector()
connector.load_credentials(
json.loads(credentials_dict[DB_CREDENTIALS_DICT_TOKEN_KEY])
)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break
pass

View File

@@ -1,197 +0,0 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_gmail_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
def get_gmail_creds_for_authorized_user(
token_json_str: str,
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Gmail tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh gmail access token due to: {e}")
return None
return None
def get_gmail_creds_for_service_account(
service_account_key_json_str: str,
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
)
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_gmail_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_gmail_service_account_key()
credential_dict = {
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_gmail_cred() -> None:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -1,4 +0,0 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

View File

@@ -1,556 +1,400 @@
import io
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from enum import Enum
from itertools import chain
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_google_docs_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SCOPE_DOC_URL
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.connectors.interfaces import SlimConnector
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
return []
return [s.strip() for s in string.split(",") if s.strip()]
GoogleDriveFileType = dict[str, Any]
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
def _run_drive_file_query(
service: discovery.Resource,
query: str,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = add_retries(
lambda: (
service.files()
.list(
corpora="allDrives"
if include_shared
else "user", # needed to search through shared drives
pageSize=batch_size,
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
fields=(
"nextPageToken, files(mimeType, id, name, permissions, "
"modifiedTime, webViewLink, shortcutDetails)"
),
pageToken=next_page_token,
q=query,
)
.execute()
)
)()
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
)
if continue_on_failure:
continue
raise
else:
yield file
def _get_folder_id(
service: discovery.Resource,
parent_id: str,
folder_name: str,
include_shared: bool,
follow_shortcuts: bool,
) -> str | None:
"""
Get the ID of a folder given its name and the ID of its parent folder.
"""
query = f"'{parent_id}' in parents and name='{folder_name}' and "
if follow_shortcuts:
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
else:
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
# TODO: support specifying folder path in shared drive rather than just `My Drive`
results = add_retries(
lambda: (
service.files()
.list(
q=query,
spaces="drive",
fields="nextPageToken, files(id, name, shortcutDetails)",
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
)
.execute()
)
)()
items = results.get("files", [])
folder_id = None
if items:
if follow_shortcuts and "shortcutDetails" in items[0]:
folder_id = items[0]["shortcutDetails"]["targetId"]
else:
folder_id = items[0]["id"]
return folder_id
def _get_folders(
service: discovery.Resource,
continue_on_failure: bool,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
for file in _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
):
# Need to check this since file may have been a target of a shortcut
# and not necessarily a folder
if file["mimeType"] == DRIVE_FOLDER_TYPE:
yield file
else:
pass
def _get_files(
service: discovery.Resource,
continue_on_failure: bool,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
def _convert_single_file(
creds: Any, primary_admin_email: str, file: dict[str, Any]
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
docs_service = get_google_docs_service(creds, user_email=user_email)
return convert_drive_item_to_document(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
)
return files
def _process_files_batch(
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
for doc in executor.map(convert_func, files):
if doc:
doc_batch.append(doc)
if len(doc_batch) >= batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def get_all_files_batched(
service: discovery.Resource,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID.
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if traverse_subfolders and folder_id is not None:
folder_ids_traversed = folder_ids_traversed or []
subfolders = _get_folders(
service=service,
folder_id=folder_id,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
for subfolder in subfolders:
if subfolder["id"] not in folder_ids_traversed:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
folder_ids_traversed.append(subfolder["id"])
yield from get_all_files_batched(
service=service,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=subfolder["id"],
traverse_subfolders=traverse_subfolders,
folder_ids_traversed=folder_ids_traversed,
)
else:
logger.debug(
"Skipping subfolder since already traversed: " + subfolder["name"]
)
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
# if specified, will only index files in these folders
folder_paths: list[str] | None = None,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# OLD PARAMETERS
folder_paths: list[str] | None = None,
include_shared: bool | None = None,
follow_shortcuts: bool | None = None,
only_org_public: bool | None = None,
continue_on_failure: bool | None = None,
) -> None:
self.folder_paths = folder_paths or []
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if (
not include_shared_drives
and not include_my_drives
and not shared_folder_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.batch_size = batch_size
self.include_shared = include_shared
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
@staticmethod
def _process_folder_paths(
service: discovery.Resource,
folder_paths: list[str],
include_shared: bool,
follow_shortcuts: bool,
) -> list[str]:
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
folder_ids: list[str] = []
for path in folder_paths:
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = _get_folder_id(
service=service,
parent_id=parent_id,
folder_name=folder_name,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
)
if found_parent_id is None:
raise ValueError(
(
f"Folder '{folder_name}' in path '{path}' "
"not found in Google Drive"
)
)
parent_id = found_parent_id
folder_ids.append(parent_id)
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
return folder_ids
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GOOGLE_DRIVE,
)
return new_creds_dict
def _fetch_docs_from_drive(
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
all_drive_ids = set()
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first becuase they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
):
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
start=start,
end=end,
)
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = self._requested_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
): email
for email in self._all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = self._requested_folder_ids - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
service = discovery.build("drive", "v3", credentials=self.creds)
folder_ids: Sequence[str | None] = self._process_folder_paths(
service, self.folder_paths, self.include_shared, self.follow_shortcuts
# Create a larger process pool for file conversion
convert_func = partial(
_convert_single_file, self.creds, self.primary_admin_email
)
if not folder_ids:
folder_ids = [None]
file_batches = chain(
*[
get_all_files_batched(
service=service,
continue_on_failure=self.continue_on_failure,
include_shared=self.include_shared,
follow_shortcuts=self.follow_shortcuts,
batch_size=self.batch_size,
time_range_start=start,
time_range_end=end,
folder_id=folder_id,
traverse_subfolders=True,
# Process files in larger batches
LARGE_BATCH_SIZE = self.batch_size * 4
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
for folder_id in folder_ids
]
)
for files_batch in file_batches:
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
files_to_process = []
if self.only_org_public:
if "permissions" not in file:
continue
if not any(
permission["type"] == "domain"
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when pulling a file from Google Drive"
)
yield doc_batch
# Process any remaining files
if files_to_process:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive()
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
# need to subtract 10 minutes from start time to account for modifiedTime
# propogation if a document is modified, it takes some time for the API to
# reflect these changes if we do not have an offset, then we may "miss" the
# update when polling
yield from self._fetch_docs_from_drive(start, end)
try:
yield from self._extract_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def _extract_slim_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
is_slim=True,
start=start,
end=end,
):
if doc := build_slim_document(file):
slim_batch.append(doc)
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
yield slim_batch
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e

View File

@@ -1,229 +0,0 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
return None
return None
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_cred() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -1,7 +1,4 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"

View File

@@ -0,0 +1,197 @@
import io
from datetime import datetime
from datetime import timezone
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.section_extraction import get_document_sections
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.google_utils.resources import GoogleDriveService
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
if mime_type == GDriveMimeType.WORD_DOC.value:
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
# Skip files that are folders
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
logger.info("Ignoring Drive Folder Filetype")
return None
sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
return Document(
id=file["webViewLink"],
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
# Skip files that are folders or shortcuts
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
},
)

View File

@@ -0,0 +1,222 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from googleapiclient.discovery import Resource # type: ignore
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.utils.logger import setup_logger
logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
return time_range_filter
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_files_in_parent(
service: Resource,
parent_id: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
def get_all_files_in_my_drive(
service: Any,
update_traversed_ids_func: Callable,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = "trashed = false and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]

View File

@@ -0,0 +1,18 @@
from enum import Enum
from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]

View File

@@ -0,0 +1,105 @@
from typing import Any
from pydantic import BaseModel
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.models import Section
class CurrentHeading(BaseModel):
id: str
text: str
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
# @Chris
return (
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
)
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
return "".join(text_elements)
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
for element in content:
if "paragraph" not in element:
continue
paragraph = element["paragraph"]
# Check if this is a heading
if (
"paragraphStyle" in paragraph
and "namedStyleType" in paragraph["paragraphStyle"]
):
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
if is_heading or is_title:
# If we were building a previous section, add it to sections list
if current_heading is not None and current_section:
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
continue
# Add content to current section
if current_heading is not None:
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
# Don't forget to add the last section
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
return sections

View File

@@ -0,0 +1,107 @@
import json
from typing import cast
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_google_oauth_creds(
token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
scopes=GOOGLE_SCOPES[source],
)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception:
logger.exception("Failed to refresh google drive access token due to:")
return None
return None
def get_google_creds(
credentials: dict[str, str],
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_oauth_creds(
token_json_str=access_token_json_str, source=source
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=GOOGLE_SCOPES[source]
)
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(
f"Unable to access {source} - service account credentials are invalid."
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
f"Unable to access {source} - unknown credential structure."
)
return creds, new_creds_dict

View File

@@ -0,0 +1,237 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
from danswer.connectors.google_utils.shared_constants import (
ONYX_SCOPE_INSTRUCTIONS,
)
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
elif source == DocumentSource.GMAIL:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
else:
raise ValueError(f"Unsupported source: {source}")
def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
drive_service = get_drive_service(creds)
user_info = (
drive_service.about()
.get(
fields="user(emailAddress)",
)
.execute()
)
email = user_info.get("user", {}).get("emailAddress")
elif source == DocumentSource.GMAIL:
gmail_service = get_gmail_service(creds)
user_info = (
gmail_service.users()
.getProfile(
userId="me",
fields="emailAddress",
)
.execute()
)
email = user_info.get("emailAddress")
else:
raise ValueError(f"Unsupported source: {source}")
return email
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
source: DocumentSource,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
# Get user email from Google API so we know who
# the primary admin is for this connector
try:
email = _get_current_oauth_user(creds, source)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
)
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(
app_credentials: GoogleAppCredentials, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_google_app_cred(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
else:
raise ValueError(f"Unsupported source: {source}")
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(
service_account_key: GoogleServiceAccountKey, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.json(),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_service_account_key(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -0,0 +1,125 @@
import re
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
try:
return request.execute()
except HttpError as error:
attempt += 1
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Extract 'Retry after' timestamp from error message
match = re.search(
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
str(error),
)
if match:
retry_after_timestamp = match.group(1)
retry_after_dt = datetime.strptime(
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
).replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
sleep_time = max(
int((retry_after_dt - current_time).total_seconds()),
0,
)
else:
logger.error(
f"No Retry-After header or timestamp found in error message: {error}"
)
sleep_time = 60
sleep_time += 3 # Add a buffer to be safe
logger.info(
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
)
time.sleep(sleep_time)
else:
raise
# If we've exhausted all attempts
raise Exception(f"Failed to execute request after {max_attempts} attempts")
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
results = add_retries(
lambda: retrieval_function(**request_kwargs).execute()
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
if list_key:
for item in results.get(list_key, []):
yield item
else:
yield results

View File

@@ -0,0 +1,63 @@
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
class GmailService(Resource):
pass
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)
def get_gmail_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GmailService:
return _get_google_service("gmail", "v1", creds, user_email)

View File

@@ -0,0 +1,40 @@
from danswer.configs.constants import DocumentSource
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_SCOPES = {
DocumentSource.GOOGLE_DRIVE: [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
}
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Auth scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)
# This is the maximum number of threads that can be retrieved at once
SLIM_BATCH_SIZE = 500

View File

@@ -56,7 +56,11 @@ class PollConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import builtins
import functools
import itertools
import tempfile
from typing import Any
from unittest import mock
from urllib.parse import urlparse
@@ -18,6 +19,8 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
@mock.patch.object(
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import datetime
import itertools
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
@@ -25,6 +26,8 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
def pywikibot_timestamp_to_utc_datetime(
timestamp: pywikibot.time.Timestamp,
@@ -121,7 +124,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [

View File

@@ -251,7 +251,11 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []

View File

@@ -391,7 +391,11 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -1,10 +1,7 @@
from collections.abc import Iterator
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects import Ticket # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
@@ -20,43 +17,244 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.retry_wrapper import retry_builder
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
MAX_PAGE_SIZE = 30 # Zendesk API maximum
class ZendeskCredentialsNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__(
"Zendesk Credentials are not set up, was load_credentials called?"
)
class ZendeskClient:
def __init__(self, subdomain: str, email: str, token: str):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
@retry_builder()
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
response.raise_for_status()
return response.json()
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
try:
while True:
data = client.make_request("guide/content_tags", params)
for tag in data.get("records", []):
content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
return content_tags
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
)
update_time = time_str_to_utc(article.updated_at)
# build metadata
while True:
data = client.make_request("help_center/articles", params)
for article in data["articles"]:
yield article
if not data.get("meta", {}).get("has_more"):
break
params["page[after]"] = data["meta"]["after_cursor"]
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
while True:
data = client.make_request("incremental/tickets.json", params)
for ticket in data["tickets"]:
yield ticket
if not data.get("end_of_stream", False):
params["start_time"] = data["end_time"]
else:
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
def _article_to_document(
article: dict[str, Any],
content_tags: dict[str, str],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
author_id = article.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
updated_at = article.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
# Build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.label_names if label],
"labels": [str(label) for label in article.get("label_names", []) if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.content_tag_ids
for tag_id in article.get("content_tag_ids", [])
if tag_id in content_tags
],
}
# remove empty values
# Remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return Document(
id=f"article:{article.id}",
return new_author_mapping, Document(
id=f"article:{article['id']}",
sections=[
Section(link=article.html_url, text=parse_html_page_basic(article.body))
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
],
source=DocumentSource.ZENDESK,
semantic_identifier=article.title,
semantic_identifier=article["title"],
doc_updated_at=update_time,
primary_owners=[author],
primary_owners=[author] if author else None,
metadata=metadata,
)
class ZendeskClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Zendesk Client is not set up, was load_credentials called?")
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
author_id = comment.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
return new_author_mapping, comment_text
def _ticket_to_document(
ticket: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
default_subdomain: str,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
submitter_id = ticket.get("submitter")
if not submitter_id:
submitter = None
else:
submitter = (
author_map.get(submitter_id)
if submitter_id in author_map
else _fetch_author(client, submitter_id)
)
new_author_mapping = (
{submitter_id: submitter} if submitter_id and submitter else None
)
updated_at = ticket.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
metadata: dict[str, str | list[str]] = {}
if status := ticket.get("status"):
metadata["status"] = status
if priority := ticket.get("priority"):
metadata["priority"] = priority
if tags := ticket.get("tags"):
metadata["tags"] = tags
if ticket_type := ticket.get("type"):
metadata["ticket_type"] = ticket_type
# Fetch comments for the ticket
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
comments = comments_data.get("comments", [])
comment_texts = []
for comment in comments:
new_author_mapping, comment_text = _get_comment_text(
comment, author_map, client
)
if new_author_mapping:
author_map.update(new_author_mapping)
comment_texts.append(comment_text)
comments_text = "\n\n".join(comment_texts)
subject = ticket.get("subject")
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
ticket_url = ticket.get("url")
subdomain = (
ticket_url.split("//")[1].split(".zendesk.com")[0]
if ticket_url
else default_subdomain
)
ticket_display_url = (
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
)
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=[submitter] if submitter else None,
metadata=metadata,
)
class ZendeskConnector(LoadConnector, PollConnector):
@@ -66,44 +264,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
self.content_type = content_type
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
self.subdomain = ""
# Fetch all tags ahead of time
self.content_tags: dict[str, str] = {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -112,87 +276,23 @@ class ZendeskConnector(LoadConnector, PollConnector):
.replace("https://", "")
.split(".zendesk.com")[0]
)
self.subdomain = subdomain
self.zendesk_client = Zenpy(
subdomain=subdomain,
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
self.client = ZendeskClient(
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _ticket_to_document(self, ticket: Ticket) -> Document:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
owner = None
if ticket.requester and ticket.requester.name and ticket.requester.email:
owner = [
BasicExpertInfo(
display_name=ticket.requester.name, email=ticket.requester.email
)
]
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
metadata: dict[str, str | list[str]] = {}
if ticket.status is not None:
metadata["status"] = ticket.status
if ticket.priority is not None:
metadata["priority"] = ticket.priority
if ticket.tags:
metadata["tags"] = ticket.tags
if ticket.type is not None:
metadata["ticket_type"] = ticket.type
# Fetch comments for the ticket
comments = self.zendesk_client.tickets.comments(ticket=ticket)
# Combine all comments into a single text
comments_text = "\n\n".join(
[
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
for comment in comments
if comment.body
]
)
# Combine ticket description and comments
description = (
ticket.description
if hasattr(ticket, "description") and ticket.description
else ""
)
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
# Extract subdomain from ticket.url
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
# Build the html url for the ticket
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
return Document(
id=f"zendesk_ticket_{ticket.id}",
sections=[Section(link=ticket_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=owner,
metadata=metadata,
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if self.content_type == "articles":
yield from self._poll_articles(start)
@@ -204,26 +304,30 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = (
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
if start is None
else self.zendesk_client.help_center.articles.incremental( # type: ignore
start_time=int(start)
)
)
articles = _get_articles(self.client, start_time=int(start) if start else None)
# This one is built on the fly as there may be more many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
doc_batch = []
for article in articles:
if (
article.body is None
or article.draft
article.get("body") is None
or article.get("draft")
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.label_names
for label in article.get("label_names", [])
)
):
continue
doc_batch.append(_article_to_document(article, self.content_tags))
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
@@ -234,10 +338,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
author_map: dict[str, BasicExpertInfo] = {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
)
while True:
doc_batch = []
@@ -246,10 +354,20 @@ class ZendeskConnector(LoadConnector, PollConnector):
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.status == "deleted":
if ticket.get("status") == "deleted":
continue
doc_batch.append(self._ticket_to_document(ticket))
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -267,7 +385,6 @@ class ZendeskConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()

View File

@@ -1,3 +1,5 @@
import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackBotConfig
@@ -48,3 +50,16 @@ def validate_channel_names(
)
return cleaned_channel_names
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
60 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
)
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -1,18 +1,34 @@
import asyncio
import os
import signal
import sys
import threading
import time
from threading import Event
from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -46,6 +62,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
@@ -53,17 +70,23 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
)
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
# the cause.
@@ -77,10 +100,213 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# this is always (currently) the user id of Slack's official slackbot
# This is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
self.acquire_thread = threading.Thread(
target=self.acquire_tenants_loop, daemon=True
)
self.heartbeat_thread = threading.Thread(
target=self.heartbeat_loop, daemon=True
)
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
DanswerRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = (
f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
)
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
def start_socket_client(
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
return
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
logger.info("Shutdown complete")
sys.exit(0)
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -172,7 +398,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)
return False
@@ -247,7 +473,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -269,14 +495,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
logger.info(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.notice(f"Rephrased message: {msg}")
logger.info(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.notice(f"Received Slack message: {msg}")
logger.info(f"Received Slack message: {msg}")
if tagged:
logger.debug("User tagged DanswerBot")
@@ -477,94 +703,21 @@ def _get_socket_client(
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()
set_is_ee_based_on_env_variable()
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
try:
# Keep the main thread alive
while tenant_handler.running:
time.sleep(1)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)
except Exception:
logger.exception("Fatal error in main thread")
tenant_handler.shutdown(None, None)

View File

@@ -5,20 +5,26 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.api_key import ApiKeyDescriptor
from danswer.auth.api_key import build_displayable_api_key
from danswer.auth.api_key import generate_api_key
from danswer.auth.api_key import hash_api_key
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey
from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
def is_api_key_email_address(email: str) -> bool:
return email.endswith(f"{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}")
return email.endswith(get_api_key_email_pattern())
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
@@ -60,7 +66,11 @@ def insert_api_key(
db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None
) -> ApiKeyDescriptor:
std_password_helper = PasswordHelper()
api_key = generate_api_key()
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
@@ -35,12 +36,16 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users(db_session: Session) -> int:
def get_total_users_count(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = db_session.query(User).count()
user_count = (
db_session.query(User)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.count()
)
invited_users = len(get_invited_users())
return user_count + invited_users

View File

@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -351,7 +351,11 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
@@ -438,7 +442,10 @@ def remove_credential_from_connector(
)
if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
db_session=db_session,
cc_pair_id=association.id,
)

View File

@@ -10,10 +10,7 @@ from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import (
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
@@ -424,25 +421,15 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
db_session.commit()
def delete_gmail_service_account_credentials(
user: User | None, db_session: Session
def delete_service_account_credentials(
user: User | None, db_session: Session, source: DocumentSource
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
and credential.source == source
):
db_session.delete(credential)
db_session.commit()
def delete_google_drive_service_account_credentials(
user: User | None, db_session: Session
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
db_session.delete(credential)
db_session.commit()

View File

@@ -29,6 +29,7 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
@@ -37,10 +38,10 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -309,8 +310,12 @@ async def get_async_session_with_tenant(
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
except Exception:
logger.exception("Error setting search_path.")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
@@ -318,47 +323,77 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
def set_search_path_on_checkout(

View File

@@ -219,7 +219,7 @@ def mark_attempt_partially_succeeded(
def mark_attempt_failed(
index_attempt: IndexAttempt,
index_attempt_id: int,
db_session: Session,
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
@@ -227,7 +227,7 @@ def mark_attempt_failed(
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()

View File

@@ -135,6 +135,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
@@ -734,9 +737,10 @@ class IndexAttempt(Base):
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
# Nullable because in the past, we didn't allow swapping out embedding models live
search_settings_id: Mapped[int] = mapped_column(
ForeignKey("search_settings.id"),
nullable=False,
ForeignKey("search_settings.id", ondelete="SET NULL"),
nullable=True,
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -756,7 +760,7 @@ class IndexAttempt(Base):
"ConnectorCredentialPair", back_populates="index_attempts"
)
search_settings: Mapped[SearchSettings] = relationship(
search_settings: Mapped[SearchSettings | None] = relationship(
"SearchSettings", back_populates="index_attempts"
)
@@ -918,7 +922,7 @@ class ToolCall(Base):
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
ForeignKey("chat_message.id"), nullable=False
)
# Update the relationship
@@ -1061,7 +1065,6 @@ class ChatMessage(Base):
"ToolCall",
back_populates="message",
uselist=False,
# foreign_keys=[ToolCall.message_id], # Specify foreign key if needed
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
@@ -1321,7 +1324,6 @@ class StarterMessage(TypedDict):
in Postgres"""
name: str
description: str
message: str

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -0,0 +1,111 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -37,7 +44,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
custom_headers=[header.model_dump() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@@ -147,7 +147,7 @@ class VespaIndex(DocumentIndex):
return None
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.info(f"Deploying Vespa application package to {deploy_url}")
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "danswer", "document_index", "vespa", "app_config"

View File

@@ -13,6 +13,7 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
CSV = "csv"
class FileDescriptor(TypedDict):

View File

@@ -8,12 +8,13 @@ import requests
from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@@ -52,11 +53,11 @@ def load_all_chat_files(
return files
def save_file_from_url(url: str) -> str:
def save_file_from_url(url: str, tenant_id: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_context_manager() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
response = requests.get(url)
response.raise_for_status()
@@ -75,7 +76,10 @@ def save_file_from_url(url: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url,)) for url in urls
(save_file_from_url, (url, tenant_id)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -1,3 +1,4 @@
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -7,6 +8,7 @@ from typing import TYPE_CHECKING
from typing import Union
import litellm # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -107,11 +109,10 @@ def translate_danswer_msg_to_langchain(
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files)
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -135,6 +136,18 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string()
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
@@ -145,16 +158,26 @@ def _build_content(
if files
else None
)
if not text_files:
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
if not text_files and not csv_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files:
for file in text_files or []:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
final_message_with_files += message
return final_message_with_files
@@ -164,10 +187,19 @@ def build_content_with_imgs(
message: str,
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
# Only include image files for user messages
img_files = (
[file for file in files if file.file_type == ChatFileType.IMAGE]
if message_type == MessageType.USER
else []
)
img_urls = img_urls or []
message_main_content = _build_content(message, files)
if not img_files and not img_urls:

View File

@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -73,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -194,7 +198,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
if status_code >= 400:
if isinstance(exc, BasicAuthenticationError):
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
logger.error(f"Authentication failed: {str(exc)}")
elif status_code >= 400:
error_msg = f"{str(exc)}\n"
error_msg += "".join(traceback.format_tb(exc.__traceback__))
logger.error(error_msg)
@@ -220,7 +229,6 @@ def get_application() -> FastAPI:
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
# Add the custom exception handler
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error)
application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error)
@@ -265,6 +273,9 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -277,12 +288,14 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),

View File

@@ -35,23 +35,31 @@ class BaseTokenizer(ABC):
class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
if model_name not in cls._instances:
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[model_name]
def __init__(self, encoding_name: str = "cl100k_base"):
def __init__(self, model_name: str):
if not hasattr(self, "encoder"):
import tiktoken
self.encoder = tiktoken.get_encoding(encoding_name)
self.encoder = tiktoken.encoding_for_model(model_name)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
# this ignores special tokens that the model is trained on, see encode_ordinary for details
return self.encoder.encode_ordinary(string)
def tokenize(self, string: str) -> list[str]:
return [self.encoder.decode([token]) for token in self.encode(string)]
encoded = self.encode(string)
decoded = [self.encoder.decode([token]) for token in encoded]
if len(decoded) != len(encoded):
logger.warning(
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
)
return decoded
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
@@ -74,22 +82,35 @@ class HuggingFaceTokenizer(BaseTokenizer):
return self.encoder.decode(tokens)
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE
if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
id_tuple = (model_provider, model_name)
if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]
try:
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
f"Failed to initialize tokenizer for {model_name} and fallback model"
) from fallback_error
return _TOKENIZER_CACHE[tokenizer_name]
return _TOKENIZER_CACHE[id_tuple]
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
@@ -118,11 +139,16 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
# Currently all of the viable models use the same sentencepiece tokenizer
# OpenAI uses a different one but currently it's not supported due to quality issues
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
# LLM tokenizers are specified by strings
global _DEFAULT_TOKENIZER
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
return _DEFAULT_TOKENIZER

View File

@@ -65,7 +65,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -125,11 +125,11 @@ def stream_answer_objects(
)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
temporary_persona = fetch_ee_implementation_or_noop(
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
persona = temporary_persona if temporary_persona else chat_session.persona
@@ -253,7 +253,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
# won't be any FileChatDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):

View File

@@ -13,6 +13,10 @@ from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()
MOST_BASIC_PROMPT = "You are a helpful AI assistant."
@@ -136,14 +140,23 @@ def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
before the list exceeds the maximum"""
running_sum = 0
if not lst:
logger.warning("Empty message history passed to find_last_index")
return 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
logger.error(
f"Last message alone is too large! max_prompt_tokens: {max_prompt_tokens}, message_token_counts: {lst}"
)
raise ValueError("Last message alone is too large!")
return last_ind

View File

@@ -0,0 +1,72 @@
import redis
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_pool import get_redis_client
class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""
def __init__(self, tenant_id: str | None, id: int) -> None:
self.tenant_id: str | None = tenant_id
self.id: int = id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.id, search_settings_id, self.redis
)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id

View File

@@ -0,0 +1,97 @@
import time
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.redis.redis_object_helper import RedisObjectHelper
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)

View File

@@ -0,0 +1,145 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime
class RedisConnectorDelete:
"""Manages interactions with redis for deletion tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectordeletion"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def get_remaining(self) -> int:
# todo: move into fence
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
@property
def payload(self) -> RedisConnectorDeletionFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectordeletion_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.PREFIX}_{self.id}_{uuid4()}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
custom_task_id = self._generate_task_id()
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
self.redis.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=self.tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,146 @@
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
class RedisConnectorIndex:
"""Manages interactions with redis for indexing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorindexing"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorindexing_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
def __init__(
self,
tenant_id: str | None,
id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.search_settings_id = search_settings_id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
self.generator_progress_key = (
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
)
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
def generate_generator_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_progress(self) -> int | None:
"""Returns None if the key doesn't exist. The"""
# TODO: move into fence?
bytes = self.redis.get(self.generator_progress_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def get_completion(self) -> int | None:
# TODO: move into fence?
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
status = int(cast(int, bytes))
return status
def reset(self) -> None:
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,171 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
class RedisConnectorPrune:
"""Manages interactions with redis for pruning tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorpruning"
FENCE_PREFIX = f"{PREFIX}_fence"
# phase 1 - geneartor task and progress signals
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpruning+generator
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorpruning_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorpruning_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_remaining(self) -> int:
# todo: move into fence
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
def get_active_task_count(self) -> int:
"""Count of active pruning tasks"""
count = 0
for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
count += 1
return count
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
pruning tasks to be processed ... just after the generator completes."""
fence_bytes = self.redis.get(self.generator_complete_key)
if fence_bytes is None:
return None
fence_int = cast(int, fence_bytes)
return fence_int
@generator_complete.setter
def generator_complete(self, payload: int | None) -> None:
"""Set the payload to an int to set the fence, otherwise if None it will
be deleted"""
if payload is None:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generate_tasks(
self,
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
if not cc_pair:
return None
for doc_id in documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
self.redis.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=self.tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,34 @@
import redis
class RedisConnectorStop:
"""Manages interactions with redis for stop signaling. Should only be accessed
through RedisConnector."""
FENCE_PREFIX = "connectorstop_fence"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id: int = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,99 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.document_set import construct_document_select_by_docset
from danswer.redis.redis_object_helper import RedisObjectHelper
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, payload: int | None) -> None:
if payload is None:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload)
@property
def payload(self) -> int | None:
bytes = self.redis.get(self.fence_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,91 @@
from abc import ABC
from abc import abstractmethod
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.redis.redis_pool import get_redis_client
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: str):
self._tenant_id: str | None = tenant_id
self._id: str = id
self.redis = get_redis_client(tenant_id=tenant_id)
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass

Some files were not shown because too many files have changed in this diff Show More