Compare commits


71 Commits

Author SHA1 Message Date
pablodanswer
f0a21b74d4 functional minus pagination / time based polling 2024-10-05 16:48:02 -07:00
pablodanswer
5c521a7916 validated 2024-10-05 16:32:48 -07:00
pablodanswer
28e65669b4 add multi tenant alembic (#2589) 2024-10-05 21:59:15 +00:00
pablodanswer
493c3d7314 Add only multi tenant dependency injection (#2588)
* add only dependency injection

* quick typing fix

* additional non-dependency context

* update nits
2024-10-05 21:08:41 +00:00
pablodanswer
b04e9e9b67 Improved api key forms + fix non-submittable azure (#2654) 2024-10-04 19:29:45 -07:00
rkuo-danswer
3755e575a5 harden connections to redis (#2677)
* set broker_connection_retry_on_startup to silence deprecation warning (we're OK with retrying on startup)

* env var for CELERY_BROKER_POOL_LIMIT

* add redis retry on timeout and health check interval

* set socket_keepalive = True

* remove shadow declaration of REDIS_HEALTH_CHECK_INTERVAL, add socket_keepalive_options where possible

* fix mypy complaint

* pass through vars in docker compose

* remove extra '='

* wrap in a try
2024-10-04 16:00:48 +00:00
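The options listed in the commit above map directly onto standard redis-py client parameters. A minimal sketch of a hardened client, assuming redis-py and Linux TCP keepalive socket options (the host, port, and tuning values are illustrative, not the project's actual configuration):

import socket

import redis

REDIS_HEALTH_CHECK_INTERVAL = 60  # illustrative; the real value is environment-driven

r = redis.Redis(
    host="localhost",
    port=6379,
    retry_on_timeout=True,  # retry commands that hit a socket timeout
    health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,  # PING idle connections periodically
    socket_keepalive=True,  # enable TCP keepalive so dead peers are detected
    socket_keepalive_options={  # Linux-specific keepalive tuning
        socket.TCP_KEEPIDLE: 60,
        socket.TCP_KEEPINTVL: 15,
        socket.TCP_KEEPCNT: 3,
    },
)
r.ping()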
rkuo-danswer
63655cfbed update_single should be optimized for a single call now (#2671)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-04 15:43:04 +00:00
rkuo-danswer
7f788e4b1e bump celery to 5.5.0b4 (#2681) 2024-10-04 05:54:32 +00:00
Chris Weaver
1362d4b583 Allow config of background concurrency (#2648)
* Allow config of background concurrency

* Add comment

* Fix light worker

* use backslashes to continue lines in supervisord with bash

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@danswer.ai>
2024-10-04 00:55:28 +00:00
rkuo-danswer
4f47004d47 disable another flaky assert (#2678) 2024-10-04 00:25:46 +00:00
rkuo-danswer
3fdd233e84 delete directly via selection instead of making multiple calls to get chunk ids and delete each one (#2666) 2024-10-03 01:57:25 +00:00
Yuhong Sun
0c54d9d57d Unstructured Update Copy (#2668) 2024-10-02 17:48:11 -07:00
hagen-danswer
c2088602e1 Implement source testing framework + Slack (#2650)
* Added permission sync tests for Slack

* moved folders

* prune test + mypy

* added wait for indexing to cc_pair creation

* commented out check

* should fix other tests

* added slack channel pool

* fixed everything and mypy

* reduced flake
2024-10-02 23:16:07 +00:00
Chris Weaver
b3c367d09c [tiny] adjust user group sync log (#2664) 2024-10-02 18:01:40 +00:00
pablodanswer
457d32fef0 add clarity around assistants and names (#2663) 2024-10-02 18:00:06 +00:00
pablodanswer
af187c6cfe Better virtualization (#2653) 2024-10-02 11:14:59 -07:00
rkuo-danswer
a0235b7b7b replace trivy download endpoint due to db download flakiness on their en… (#2661)
* disable trivy for the moment due to db download flakiness on their end causing the action to fail

* try hardcoding to amazon registry as others have suggested
2024-10-02 17:13:19 +00:00
pablodanswer
a30de693cb Clean, memoized assistant ordering (#2655)
* updated refresh

* memoization and so on

* nit

* build issue
2024-10-02 16:15:54 +00:00
pablodanswer
07aeea69e7 Dupe welcome modal logic (#2656) 2024-10-01 20:11:39 -07:00
Evan Lohn
bd40328a73 fix typo 2024-10-01 20:10:37 -07:00
Chris Weaver
b8232e0681 Update litellm to fix bedrock models (#2649) 2024-10-01 20:09:57 -07:00
Yuhong Sun
fffb9c155a Redis Cache for KV Store (#2603)
* k

* k

* k

* k
2024-10-01 18:31:18 +00:00
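The commit messages above give no detail, so the following is only a generic read-through cache sketch under the assumption that Redis sits in front of a slower persistent key-value store; every name in it is hypothetical:

import json
from typing import Any

import redis


class ReadThroughKVCache:
    """Hypothetical read-through cache: check Redis first, fall back to a slower
    persistent store, and populate Redis on a miss."""

    def __init__(self, backing_store: dict[str, Any], r: redis.Redis, ttl: int = 300) -> None:
        self._store = backing_store  # stand-in for the real persistent KV store
        self._redis = r
        self._ttl = ttl

    def load(self, key: str) -> Any:
        cached = self._redis.get(key)
        if cached is not None:
            return json.loads(cached)
        value = self._store[key]  # raises KeyError on a genuine miss
        self._redis.set(key, json.dumps(value), ex=self._ttl)
        return value

    def store(self, key: str, value: Any) -> None:
        self._store[key] = value
        self._redis.set(key, json.dumps(value), ex=self._ttl)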
rkuo-danswer
f513c5bbed sync up when checks run with branch protection required checks (#2628) 2024-10-01 17:59:10 +00:00
pablodanswer
9a4e51a18e add default model + minor fixes (#2638)
* add default model + minor fixes

* fix build

* minor additional fix

* build fix
2024-10-01 17:43:43 +00:00
rkuo-danswer
2f2fc08553 raise redis connections and using blocking connection pool for more d… (#2635)
* raise redis connections and using blocking connection pool for more deterministic behavior

* improve comment
2024-10-01 17:27:17 +00:00
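A blocking pool makes exhaustion behaviour deterministic: a caller that cannot get a connection waits up to a timeout instead of failing immediately. A minimal sketch assuming redis-py (pool size and timeout are illustrative):

import redis

pool = redis.BlockingConnectionPool(
    host="localhost",
    port=6379,
    max_connections=50,  # ceiling on concurrent connections from this process
    timeout=10,          # seconds a caller waits for a free connection before raising
)
r = redis.Redis(connection_pool=pool)
r.ping()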
pablodanswer
c68c6fdc44 welcome flow 2024-10-01 10:34:53 -07:00
hagen-danswer
834c76e30a Added quotes to project name to handle reserved words (#2639) 2024-10-01 10:32:41 -07:00
rkuo-danswer
ec02665ffa run the nightly tag overnight relative to pacific time (#2637) 2024-10-01 16:36:40 +00:00
pablodanswer
3fa1b18306 update nav link name (#2643)
* update nav link name

* underscore -> dash
2024-10-01 16:34:30 +00:00
Chris Weaver
c9bdf4c443 Update CONTRIBUTING.md 2024-10-01 08:46:25 -07:00
Yuhong Sun
e229d27734 Unstructured UI (#2636)
* checkpoint

* k

* k

* need frontend

* add api key check + ui component

* add proper ports + icons + functions

* k

* k

* k

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-10-01 04:50:03 +00:00
rkuo-danswer
140c5b3957 don't push integration testing docker images (#2584)
* experiment with build and no push

* use slightly more descriptive and consistent tags and names

* name integration test workflow consistently with other workflows

* put the tag back

* try runs-on s3 backend

* try adding runs-on cache

* add with key

* add a dummy path

* forget about multiline

* maybe we don't need runs-on cache immediately

* lower ram slightly, name test with a version bump

* don't need to explicitly include runs-on/cache for docker caching

* comment out flaky portion of knowledge chat test

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-10-01 01:00:47 +00:00
Chris Weaver
3e511497d2 Fix overflow of prompt library table (#2606) 2024-09-30 15:31:12 +00:00
hagen-danswer
b0056907fb Added permissions syncing for slack (#2602)
* Added permissions syncing for slack

* add no email case handling

* mypy fixes

* frontend

* minor cleanup

* param tweak
2024-09-30 15:14:43 +00:00
Chris Weaver
728a41a35a Add heartbeat to indexing (#2595) 2024-09-29 19:26:40 -07:00
Chris Weaver
ef8dda2d47 Rely on PVC (#2604) 2024-09-29 17:30:39 -07:00
pablodanswer
15283b3140 prevent nextFormStep unless credential fully set up (#2599) 2024-09-29 22:47:45 +00:00
Chris Weaver
e159b2e947 Fix default assistant (#2600)
* Fix default assistant

* Remove log

* Add newline
2024-09-29 22:47:14 +00:00
Jeff Knapp
9155800fab EKS initial deployment (#2154)
Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-09-29 15:51:31 -07:00
pablodanswer
a392ef0541 Show transition card if no connectors (#2597)
* show transition card if no connectors

* squash

* update apos
2024-09-29 22:35:41 +00:00
Yuhong Sun
5679f0af61 Minor Query History Fix (#2594) 2024-09-29 10:54:08 -07:00
rkuo-danswer
ff8db71cb5 don't write a nightly tag to the same commit more than once (#2585)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-29 10:36:08 -07:00
hagen-danswer
1cff2b82fd Global Curator Fix + Testing (#2591)
* Global Curator Fix

* test fix
2024-09-28 20:14:39 +00:00
Chris Weaver
50dd3c8beb Add size limit to jira tickets (#2586) 2024-09-28 12:49:13 -07:00
hagen-danswer
66a459234d Minor role display refactor (#2578) 2024-09-27 16:50:03 +00:00
rkuo-danswer
19e57474dc Feature/xenforo (#2497)
* Xenforo forum parser support

* clarify ssl cert reqs

* missed a file

* add isLoadState function, fix up xenforo for data driven connector approach

* fixing a new edge case to skip an unexpected parsed element

* change documentsource to xenforo

* make doc id unique and comment what's happening

* remove stray log line

* address code review

---------

Co-authored-by: sime2408 <simun.sunjic@gmail.com>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:36:05 +00:00
rkuo-danswer
f9638f2ea5 try user deploy key approach to tagging (#2575)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:04:55 +00:00
rkuo-danswer
fbf51b70d0 Feature/celery multi (#2470)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* multiple celery workers

* update logs as well and set prefetch multipliers appropriate to the worker intent

* add db refresh to connector deletion

* add some preliminary locking

* organize tasks into separate files

* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this.

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

* add multi workers to dev_run_background_jobs.py

* update supervisord with some recommended settings for celery

* name celery workers and shorten dev script prefixing

* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)

* fix comments

* autoscale sqlalchemy pool size to celery concurrency (allow override later?)

* supervisord needs the percent symbols escaped

* use name as primary check, some minor refactoring and type hinting too.

* addressing code review

* fix import

* fix prune_documents_task references

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 00:50:55 +00:00
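The commit above moves background work onto Celery with Redis as the broker and introduces priority queues; the concrete settings appear in the celeryconfig.py diff further down in this compare. A minimal standalone sketch of that broker setup (the broker URL, queue count, and task name are illustrative):

from celery import Celery

app = Celery("danswer_sketch", broker="redis://localhost:6379/0")

app.conf.broker_transport_options = {
    "priority_steps": list(range(10)),   # distinct priority levels on the Redis broker
    "sep": ":",                          # separator used to build per-priority queue keys
    "queue_order_strategy": "priority",  # drain higher-priority queues first
}
app.conf.task_default_priority = 5
app.conf.task_acks_late = True           # redeliver tasks if a worker dies mid-task


@app.task(name="example_sync_task")
def example_sync_task(document_id: str) -> str:
    return document_id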
hagen-danswer
b97cc01bb2 Added confluence permission syncing (#2537)
* Added confluence permission syncing

* seperated out group and doc syncing

* minorbugfix and mypy

* added frontend and fixed bug

* Minor refactor

* dealth with confluence rate limits!

* mypy fixes!!!

* addressed yuhong feedback

* primary key fix
2024-09-26 22:10:41 +00:00
rkuo-danswer
6d48fd5d99 clamp retry to max_delay (#2570) 2024-09-26 21:56:46 +00:00
Chris Weaver
1f61447b4b Add open in new tab for custom links (#2568) 2024-09-26 20:01:35 +00:00
rkuo-danswer
deee2b3513 push to docker latest when git tag contains "latest", and tag nightly (#2564)
* comment docker tag latest

* make latest builds contingent on a "latest" keyword in the tag

* v4 checkout

* nightly tag push
2024-09-26 17:40:13 +00:00
hagen-danswer
b73d66c84a Cleaned up foreign key cleanup for user group deletion (#2559)
* cleaned up fk cleanup for user group deletion

* added test for user group deletion
2024-09-26 03:38:01 +00:00
rkuo-danswer
c5a61f4820 Feature/test pruning (#2556)
* add test to exercise pruning

* add prettierignore

* mypy fix

* mypy again

* try getting all the env vars set up correctly

* fix ports and hostnames
2024-09-25 23:34:13 +00:00
pablodanswer
ea4a3cbf86 update folder list (#2563) 2024-09-25 16:25:45 -07:00
rkuo-danswer
166514cedf ssl_ca_certs should default to None, not "". (#2560)
* ssl_ca_certs should default to None, not "".

otherwise, if ssl is enabled it will look for the cert on an empty path and fail.

* mypy fix
2024-09-25 19:56:21 +00:00
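The fix above amounts to treating an empty value as "not set". A minimal sketch of that pattern (the variable name mirrors the one in the commit; reading it from the environment like this is an assumption):

import os

# An empty string must not be passed as a CA path: with SSL enabled, redis-py would
# try to open "" and fail. Normalize empty or missing values to None instead.
REDIS_SSL_CA_CERTS = os.environ.get("REDIS_SSL_CA_CERTS") or None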
pablodanswer
be50ae1e71 flex none (#2558) 2024-09-25 10:19:37 -07:00
pablodanswer
f89504ec53 Update some ux edge cases (#2545)
* update some ux edge cases

* update some formatting / ports
2024-09-25 16:46:43 +00:00
trial-danswer
6b3213b1e4 fix typo (#2543)
* fix typo

* Update EmbeddingFormPage.tsx

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-09-25 01:25:46 +00:00
Chris Weaver
48577bf0e4 Allow = in tag filter (#2548)
* Allow = in tag filter

* Rename func
2024-09-24 21:37:35 +00:00
pablodanswer
c59d1ff0a5 Update merge queue logic (#2554)
* update merge queue logic

* remove space
2024-09-24 18:45:05 +00:00
pablodanswer
ba38dec592 ensure default_assistant passed through 2024-09-24 11:35:19 -07:00
pablodanswer
f5adc3063e Update theming (#2552)
* update theming

* update

* update theming
2024-09-24 18:01:08 +00:00
hagen-danswer
8cfe80c53a Added doc_set__user_group cleanup for user_group deletion (#2551) 2024-09-24 16:09:52 +00:00
ThomaciousD
487250320b fix saml email login upsert issue 2024-09-24 07:42:08 -07:00
rkuo-danswer
c8d13922a9 rename classes and ignore deprecation warnings we mostly don't have c… (#2546)
* rename classes and ignore deprecation warnings we mostly don't have control over

* copy pytest.ini

* ignore CryptographyDeprecationWarning

* fully qualify the warning
2024-09-24 00:21:42 +00:00
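The commit above silences the warning declaratively via pytest.ini filterwarnings. For illustration only, the programmatic equivalent using the fully qualified warning class (assuming the cryptography package is installed) looks like this:

import warnings

from cryptography.utils import CryptographyDeprecationWarning

# Suppress a third-party deprecation warning we mostly don't have control over.
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)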
rkuo-danswer
cb75449cec Feature/runs on 2 (#2547)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)

* try upping the RAM for future flake proofing
2024-09-23 23:46:20 +00:00
rkuo-danswer
b66514cd21 test self hosted runner (#2541)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)
2024-09-23 21:57:23 +00:00
Chris Weaver
77650c9ee3 Fix misc tool call errors (#2544)
* Fix misc tool call errors

* Fix middleware
2024-09-23 21:00:48 +00:00
pablodanswer
316b6b99ea Tooling testing (#2533)
* add initial testing

* add custom tool testing

* update ports

* update tests - additional coverage

* update types
2024-09-23 20:09:01 +00:00
Chris Weaver
34c2aa0860 Support svg navigation items (#2542)
* Support SVG nav items

* Handle specifying custom SVGs for navbar

* Add comment

* More comment

* More comment
2024-09-23 13:22:20 -07:00
312 changed files with 39522 additions and 2964 deletions

View File

@@ -7,16 +7,17 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,7 +32,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
@@ -41,12 +42,20 @@ jobs:
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.REGISTRY_IMAGE }}:latest
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -5,14 +5,18 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
runs-on:
group: amd64-image-builders
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,13 +35,21 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -7,7 +7,8 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -35,7 +36,7 @@ jobs:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -112,8 +113,16 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,3 +1,6 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".
name: Tag Latest Version
on:
@@ -9,7 +12,9 @@ on:
jobs:
tag:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1

View File

@@ -1,19 +1,23 @@
name: Run Integration Tests
name: Run Integration Tests v2
concurrency:
group: Run-Integration-Tests-${{ github.head_ref }}
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
integration-tests:
runs-on: Amd64
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -27,25 +31,35 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: we don't need to build the Web Docker image since it's not used
# during the IT for now. We have a separate action to verify it builds
# succesfully
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
type=registry,ref=danswer/danswer-backend:it,mode=max
type=inline
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
@@ -53,11 +67,11 @@ jobs:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
type=registry,ref=danswer/danswer-model-server:it,mode=max
type=inline
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
@@ -65,11 +79,11 @@ jobs:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |
type=registry,ref=danswer/integration-test-runner:it,mode=max
type=inline
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
@@ -78,7 +92,7 @@ jobs:
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=it \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
@@ -120,6 +134,7 @@ jobs:
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
@@ -128,7 +143,9 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
danswer/integration-test-runner:it
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
continue-on-error: true
id: run_tests

View File

@@ -12,7 +12,8 @@ on:
jobs:
lint-test:
runs-on: Amd64
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:

View File

@@ -3,11 +3,14 @@ name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
mypy-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code

View File

@@ -15,10 +15,14 @@ env:
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend

View File

@@ -0,0 +1,58 @@
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -3,11 +3,14 @@ name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
backend-check:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend

View File

@@ -1,6 +1,6 @@
name: Quality Checks PR
concurrency:
group: Quality-Checks-PR-${{ github.head_ref }}
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
@@ -9,7 +9,8 @@ on:
jobs:
quality-checks:
runs-on: ubuntu-latest
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- uses: actions/checkout@v4
with:

.github/workflows/tag-nightly.yml (new file)
View File

@@ -0,0 +1,54 @@
name: Nightly Tag Push
on:
schedule:
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Check for existing nightly tag
id: check_tag
run: |
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
echo "A tag starting with 'nightly-latest' already exists on HEAD."
echo "tag_exists=true" >> $GITHUB_OUTPUT
else
echo "No tag starting with 'nightly-latest' exists on HEAD."
echo "tag_exists=false" >> $GITHUB_OUTPUT
fi
# don't tag again if HEAD already has a nightly-latest tag on it
- name: Create Nightly Tag
if: steps.check_tag.outputs.tag_exists == 'false'
env:
DATE: ${{ github.run_id }}
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
echo "Creating tag: $TAG_NAME"
git tag $TAG_NAME
- name: Push Tag
if: steps.check_tag.outputs.tag_exists == 'false'
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME

.prettierignore (new file)
View File

@@ -0,0 +1 @@
backend/tests/integration/tests/pruning/website

View File

@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.

View File

@@ -9,9 +9,9 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
@@ -21,16 +21,26 @@ if config.config_file_name is not None and config.attributes.get(
):
fileConfig(config.config_file_name)
# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_schema_options() -> tuple[str, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
@@ -54,17 +64,20 @@ def run_migrations_offline() -> None:
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
schema, _ = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
@@ -72,22 +85,28 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -101,7 +120,6 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
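The env.py changes above make migrations schema-aware for multi-tenant setups: get_schema_options() parses Alembic's -x arguments into a target schema and a create flag, and do_run_migrations() creates the schema and sets search_path before running. A hedged sketch of invoking a migration for one tenant through Alembic's CLI entry point (the schema name is hypothetical; assumes it runs from the directory containing alembic.ini):

from alembic.config import main

# Equivalent to: alembic -x schema=tenant_acme,create_schema=true upgrade head
main(argv=["-x", "schema=tenant_acme,create_schema=true", "upgrade", "head"])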

View File

@@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)

View File

@@ -9,7 +9,7 @@ import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,9 +54,7 @@ def upgrade() -> None:
)
try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
@@ -71,7 +69,7 @@ def upgrade() -> None:
)
# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")
get_kv_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found

View File

@@ -1,20 +1,20 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except ConfigNotFoundError:
except KvKeyNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -4,29 +4,29 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.key_value_store.store import KeyValueStore
from danswer.key_value_store.store import KvKeyNotFoundError
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
store: KeyValueStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",

View File

@@ -342,7 +342,6 @@ def get_database_strategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy

File diff suppressed because it is too large.

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy.orm import Session
@@ -24,12 +25,12 @@ from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import RedisPool
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
redis_pool = RedisPool()
def _get_deletion_status(
@@ -46,7 +47,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = redis_pool.get_client()
r = get_redis_client()
if not r.exists(rcd.fence_key):
return None
@@ -69,6 +70,30 @@ def get_deletion_attempt_snapshot(
)
def skip_cc_pair_pruning_by_task(
pruning_task: TaskQueueState | None, db_session: Session
) -> bool:
"""task should be the latest prune task for this cc_pair"""
if not ALLOW_SIMULTANEOUS_PRUNING:
# if only one prune is allowed at any time, then check to see if any prune
# is active
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return True
if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
# if the last task is live right now, we shouldn't start a new one
return True
return False
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
@@ -79,31 +104,26 @@ def should_prune_cc_pair(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
return False
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
# If the connector has never been pruned, then compare vs when the connector
# was created
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
return False
if not last_pruning_task.start_time:
# if the last prune task hasn't started, we shouldn't start a new one
return False
# if the last prune task has a start time, then compare against it to determine
# if we should start
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
@@ -141,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
"""Checks to see if we're listening to the named queue"""
# how to get a list of queues this worker is listening to
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
queue_names = list(worker.app.amqp.queues.consume_from.keys())
for queue_name in queue_names:
if queue_name == name:
return True
return False
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken, but the way we do it is to
check the hostname set for the celery worker, either in celeryconfig.py or on the
command line."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("heavy"):
return False
return True

View File

@@ -1,7 +1,9 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
@@ -9,6 +11,7 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
CELERY_SEPARATOR = ":"
@@ -36,12 +39,30 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
"retry_on_timeout": True,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True

View File

@@ -0,0 +1,132 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_pool import get_redis_client
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
search_settings = get_current_search_settings(db_session)
last_indexing = get_last_attempt(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
search_settings_id=search_settings.id,
db_session=db_session,
)
if last_indexing:
if (
last_indexing.status == IndexingStatus.IN_PROGRESS
or last_indexing.status == IndexingStatus.NOT_STARTED
):
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

View File

@@ -0,0 +1,140 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from celery.utils.log import get_task_logger
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True

View File

@@ -0,0 +1,120 @@
from celery import shared_task
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
)
return
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
task_logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
task_logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
raise e

View File

@@ -0,0 +1,525 @@
import traceback
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# This always raises an exception on the MIT version, which is expected
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
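# A minimal sketch of the fence/taskset pattern the three generators above share,
# assuming a redis-py client; the key names here are hypothetical.
def _sketch_fence_taskset(r: Redis, task_ids: list[str]) -> None:
    fence_key = "example_fence"
    taskset_key = "example_taskset"
    if r.exists(fence_key):
        return  # a previous pass is still being worked off
    r.delete(taskset_key)
    for task_id in task_ids:
        r.sadd(taskset_key, task_id)  # workers remove their id when they finish
    # the fence is set last so the monitor never sees a fence without its taskset
    r.set(fence_key, len(task_ids))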
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(
f"Stale document sync progress: remaining={count} initial={initial_count}"
)
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning("could not parse document set id from {key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning("could not parse document set id from {key}")
return
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return
try:
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
f"and credential_id: '{cc_pair.credential_id}'. "
f"Deleted {initial_count} docs."
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
with Session(get_sqlalchemy_engine()) as db_session:
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
@shared_task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
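# A small worked example of the retry backoff above: with max_retries=3, the
# countdown is 2 ** (retries + 4), i.e. 16s, 32s, and 64s for retries 0, 1, 2.
def _retry_countdown(retries: int) -> int:
    return 2 ** (retries + 4)
# [_retry_countdown(n) for n in range(3)] == [16, 32, 64]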

View File

@@ -10,22 +10,38 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.access.access import get_access_for_documents
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import get_document_connector_counts
from danswer.db.document import mark_document_as_synced
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
_DELETION_BATCH_SIZE = 1000
@@ -108,3 +124,88 @@ def delete_connector_credential_pair_batch(
),
)
db_session.commit()
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete_single(doc_id=document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. It's fine if the doc doesn't exist in the index; any other failure raises an exception.
document_index.update_single(document_id, fields=fields)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
# update_docs_last_modified__no_commit(
# db_session=db_session,
# document_ids=[document_id],
# )
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -14,6 +14,7 @@ from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
@@ -29,6 +30,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -48,7 +50,7 @@ def _get_connector_runner(
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
Returns an interator of document batches and whether the returned documents
Returns an iterator of document batches and whether the returned documents
are the complete list of existing documents of the connector. If the task is
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
@@ -66,12 +68,17 @@ def _get_connector_runner(
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
raise e
return ConnectorRunner(
@@ -103,15 +110,24 @@ def _run_indexing(
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)

View File

@@ -96,14 +96,20 @@ def _should_create_new_indexing(
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if not cc_pair.status.is_active() or connector.id == 0:
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
if not last_index:
@@ -416,6 +422,7 @@ def update_loop(
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -444,6 +451,7 @@ def update_loop(
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
@@ -164,13 +164,29 @@ REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# will propagate to both our redis client as well as celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting to None may help when there is a proxy in the way closing idle connections
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
CELERY_BROKER_POOL_LIMIT = int(
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
)
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
#####
# Connector Configs
#####
@@ -247,6 +263,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -270,7 +290,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
@@ -334,12 +354,10 @@ INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
#####
# File based Key Value store no longer used
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
@@ -388,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")

View File

@@ -1,3 +1,5 @@
import platform
import socket
from enum import auto
from enum import Enum
@@ -34,9 +36,12 @@ POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -46,6 +51,7 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
@@ -62,11 +68,13 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
FRESHDESK = "freshdesk"
SLACK = "slack"
WEB = "web"
GOOGLE_DRIVE = "google_drive"
@@ -104,6 +112,7 @@ class DocumentSource(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
@@ -186,6 +195,7 @@ class DanswerCeleryQueues:
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
@@ -198,3 +208,13 @@ class DanswerCeleryPriority(int, Enum):
MEDIUM = auto()
LOW = auto()
LOWEST = auto()
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
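# A minimal sketch of how these options are typically consumed when building a
# redis-py client; the host, port, and interval below are placeholder values.
def _example_redis_client() -> "redis.Redis":
    import redis

    return redis.Redis(
        host="localhost",
        port=6379,
        socket_keepalive=True,
        socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
        health_check_interval=60,
    )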

View File

@@ -194,8 +194,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
try:
text = extract_file_text(
name,
BytesIO(downloaded_file),
file_name=name,
break_on_unprocessable=False,
)
batch.append(

View File

@@ -0,0 +1,32 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
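# A small usage sketch for the two helpers above; the base url and markup are
# made-up values.
if __name__ == "__main__":
    doc_id = build_confluence_document_id(
        "https://wiki.example.com", "/pages/viewpage.action?pageId=123"
    )
    print(doc_id)  # https://wiki.example.com/pages/viewpage.action?pageId=123
    page_html = '<p>See <ri:attachment ri:filename="spec.pdf"/></p>'
    print(get_used_attachments(page_html))  # ['spec.pdf']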

View File

@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.confluence_utils import get_used_attachments
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filename currently in used
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -533,7 +519,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return None
extracted_text = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
@@ -624,19 +612,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
# The url and the id are the same
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
)
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
unused_attachments.extend(unused_page_attachments)
page_text += attachment_text
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
@@ -683,8 +674,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter and not time_filter(last_updated):
continue
attachment_url = self._attachment_to_download_link(
self.confluence_client, attachment
# The url and the id are the same
attachment_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"]
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment

View File

@@ -50,6 +50,12 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)

View File

@@ -9,6 +9,7 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{description}\n" + "\n".join(
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {self.jira_project}",
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)

View File

@@ -97,8 +97,8 @@ class DropboxConnector(LoadConnector, PollConnector):
link = self._get_shared_link(entry.path_display)
try:
text = extract_file_text(
entry.name,
BytesIO(downloaded_file),
file_name=entry.name,
break_on_unprocessable=False,
)
batch.append(

View File

@@ -15,6 +15,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
from danswer.connectors.gmail.connector import GmailConnector
@@ -42,6 +43,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.xenforo.connector import XenforoConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
@@ -57,11 +59,13 @@ def identify_connector_class(
input_type: InputType | None = None,
) -> Type[BaseConnector]:
connector_map = {
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -97,6 +101,7 @@ def identify_connector_class(
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -74,13 +74,14 @@ def _process_file(
)
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf":
elif extension == ".pdf" and pdf_pass is not None:
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
else:
file_content_raw = extract_file_text(
file_name=file_name,
file=file,
file_name=file_name,
break_on_unprocessable=True,
)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata

View File

@@ -0,0 +1,160 @@
import json
from datetime import datetime
from typing import Any
from typing import List
from typing import Optional
import requests
from bs4 import BeautifulSoup  # used to strip HTML from ticket descriptions
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
class FreshdeskConnector(PollConnector):
def __init__(
self,
api_key: str | None = None,
domain: str | None = None,
password: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.api_key = api_key
self.domain = domain
self.password = password
self.batch_size = batch_size
def ticket_link(self, tid: int) -> str:
return f"https://{self.domain}.freshdesk.com/helpdesk/tickets/{tid}"
def build_doc_sections_from_ticket(self, ticket: dict) -> List[Section]:
# Use list comprehension for building sections
return [
Section(
link=self.ticket_link(int(ticket["id"])),
text=json.dumps(
{
key: value
for key, value in ticket.items()
if isinstance(value, str)
},
default=str,
),
)
]
def strip_html_tags(self, html: str) -> str:
soup = BeautifulSoup(html, "html.parser")
return soup.get_text()
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
logger.info("Loading credentials")
self.api_key = credentials.get("freshdesk_api_key")
self.domain = credentials.get("freshdesk_domain")
self.password = credentials.get("freshdesk_password")
return None
def _process_tickets(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
assert self.api_key is not None
assert self.domain is not None
assert self.password is not None
logger.info("Processing tickets")
if any(cred is None for cred in [self.api_key, self.domain, self.password]):
raise ConnectorMissingCredentialError("freshdesk")
freshdesk_url = (
f"https://{self.domain}.freshdesk.com/api/v2/tickets?include=description"
)
response = requests.get(freshdesk_url, auth=(self.api_key, self.password))
response.raise_for_status() # raises exception when not a 2xx response
if response.status_code != 204:
tickets = json.loads(response.content)
logger.info(f"Fetched {len(tickets)} tickets from Freshdesk API")
doc_batch: List[Document] = []
for ticket in tickets:
# Convert the "created_at", "updated_at", and "due_by" values to ISO 8601 strings
for date_field in ["created_at", "updated_at", "due_by"]:
ticket[date_field] = datetime.fromisoformat(
ticket[date_field]
).strftime("%Y-%m-%d %H:%M:%S")
# Convert all other values to strings
ticket = {
key: str(value) if not isinstance(value, str) else value
for key, value in ticket.items()
}
# Checking for overdue tickets
today = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
ticket["overdue"] = "true" if today > ticket["due_by"] else "false"
# Mapping the status field values
status_mapping = {2: "open", 3: "pending", 4: "resolved", 5: "closed"}
ticket["status"] = status_mapping.get(
ticket["status"], str(ticket["status"])
)
# Stripping HTML tags from the description field
ticket["description"] = self.strip_html_tags(ticket["description"])
# Remove extra white spaces from the description field
ticket["description"] = " ".join(ticket["description"].split())
# Use list comprehension for building sections
sections = self.build_doc_sections_from_ticket(ticket)
created_at = datetime.fromisoformat(ticket["created_at"])
doc = Document(
id=ticket["id"],
sections=sections,
source=DocumentSource.FRESHDESK,
semantic_identifier=ticket["subject"],
metadata={
key: value
for key, value in ticket.items()
if isinstance(value, str)
and key not in ["description", "description_text"]
},
doc_updated_at=created_at,
)
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
yield from self._process_tickets(start, end)
if __name__ == "__main__":
import os
connector = FreshdeskConnector(
api_key=os.environ.get("FRESHDESK_API_KEY"),
domain=os.environ.get("FRESHDESK_DOMAIN"),
password=os.environ.get("FRESHDESK_PASSWORD"),
)
for doc in connector.poll_source(start=None, end=None):
print(doc)

View File

@@ -25,7 +25,7 @@ from danswer.connectors.gmail.constants import (
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
@@ -72,7 +72,7 @@ def get_gmail_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
@@ -80,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -92,14 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -111,7 +111,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -158,42 +158,40 @@ def build_service_account_creds(
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
)
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_gmail_cred() -> None:
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
get_kv_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -36,6 +36,8 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -327,16 +329,24 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
elif mime_type == GDriveMimeType.WORD_DOC.value:
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
response = service.files().get_media(fileId=file["id"]).execute()
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT

View File

@@ -28,7 +28,7 @@ from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
@@ -134,7 +134,7 @@ def get_google_drive_creds(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
@@ -142,7 +142,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -154,7 +154,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -202,32 +202,28 @@ def build_service_account_creds(
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_cred() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -40,8 +40,8 @@ def _convert_driveitem_to_document(
driveitem: DriveItem,
) -> Document:
file_text = extract_file_text(
file_name=driveitem.name,
file=io.BytesIO(driveitem.get_content().execute_query().value),
file_name=driveitem.name,
break_on_unprocessable=False,
)

View File

@@ -8,13 +8,12 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
basic_retry_wrapper = retry_builder()
def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
def _get_channels(
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])
@@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# try getting private channels as well at first
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels
def get_channel_messages(
@@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -217,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
"group_leave",
"group_archive",
"group_unarchive",
"channel_leave",
"channel_name",
"channel_join",
}
def _default_msg_filter(message: MessageType) -> bool:
def default_msg_filter(message: MessageType) -> bool:
# Don't keep messages from bots
if message.get("bot_id") or message.get("app_id"):
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
return False
return True
# Uninformative
@@ -266,14 +259,14 @@ def filter_channels(
]
def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
@@ -328,7 +321,44 @@ def get_all_docs(
)
class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel
This is nearly identical to _get_all_docs, but it returns a set of ids instead of documents.
This makes it an order of magnitude faster than _get_all_docs.
"""
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
all_doc_ids = set()
for channel in filtered_channels:
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we don't have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
return all_doc_ids
class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
@@ -349,6 +379,16 @@ class SlackPollConnector(PollConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_source_ids(self) -> set[str]:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)
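A minimal sketch of how a pruning pass might use retrieve_all_source_ids: compare the ids Slack currently exposes against the ids already indexed, then delete whatever no longer exists. The module path, fetch_indexed_ids, and delete_doc are assumptions supplied by the caller, not part of this diff.

from danswer.connectors.slack.connector import SlackPollConnector  # assumed module path

def prune_stale_slack_docs(
    connector: SlackPollConnector,
    fetch_indexed_ids,  # () -> set[str], e.g. backed by get_document_ids_for_connector_credential_pair shown later in this diff
    delete_doc,         # (doc_id: str) -> None, e.g. backed by delete_single shown later in this diff
) -> int:
    # ids of the form f"{channel_id}__{ts}" that still exist in Slack
    live_ids = connector.retrieve_all_source_ids()
    indexed_ids = fetch_indexed_ids()
    stale_ids = indexed_ids - live_ids
    for doc_id in stale_ids:
        delete_doc(doc_id)
    return len(stale_ids)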
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
@@ -356,7 +396,7 @@ class SlackPollConnector(PollConnector):
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,

View File

@@ -10,11 +10,13 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
@@ -34,7 +36,7 @@ def get_message_link(
)
def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
return logged_call
def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)
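A short usage sketch of the two public wrappers, showing how the layered helpers (logging, rate limiting, retries, and cursor pagination) compose. The token and channel id are placeholders.

import os
from slack_sdk import WebClient

from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries

client = WebClient(token=os.environ["SLACK_BOT_TOKEN"])  # placeholder token source

# Single call: logged, rate limited, and retried on transient failures.
info = make_slack_api_call_w_retries(client.conversations_info, channel="C0123456789")

# Paginated call: same wrapping, plus cursor-based pagination handled internally.
for page in make_paginated_slack_api_call_w_retries(
    client.conversations_list, types=["public_channel"]
):
    for channel in page["channels"]:
        print(channel["name"])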
def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,

View File

@@ -0,0 +1,244 @@
"""
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
forum, followed by the board name. For example:
base_url = 'https://www.example.com/forum/boards/some-topic/'
The `load_from_state` method is used to load documents from the forum board configured via `base_url`.
"""
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
import pytz
import requests
from bs4 import BeautifulSoup
from bs4 import Tag
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_title(soup: BeautifulSoup) -> str:
el = soup.find("h1", "p-title-value")
if not el:
return ""
title = el.text
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
title = title.replace(char, "_")
return title
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
page_tags = soup.select("li.pageNav-page")
page_numbers = []
for button in page_tags:
if re.match(r"^\d+$", button.text):
page_numbers.append(button.text)
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
all_pages = []
for x in range(1, int(max_pages) + 1):
all_pages.append(f"{url}page-{x}")
return all_pages
def parse_post_date(post_element: BeautifulSoup) -> datetime:
el = post_element.find("time")
if not isinstance(el, Tag) or "datetime" not in el.attrs:
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
date_value = el["datetime"]
# Ensure date_value is a string (if it's a list, take the first element)
if isinstance(date_value, list):
date_value = date_value[0]
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
return datetime_to_utc(post_date)
def scrape_page_posts(
soup: BeautifulSoup,
page_index: int,
url: str,
initial_run: bool,
start_time: datetime,
) -> list:
title = get_title(soup)
documents = []
for post in soup.find_all("div", class_="message-inner"):
post_date = parse_post_date(post)
if initial_run or post_date > start_time:
el = post.find("div", class_="bbWrapper")
if not el:
continue
post_text = el.get_text(strip=True) + "\n"
author_tag = post.find("a", class_="username")
if author_tag is None:
author_tag = post.find("span", class_="username")
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
# TODO: if a caller calls this for each page of a thread, it may see the
# same post multiple times if there is a sticky post
# that appears on each page of a thread.
# it's important to generate unique doc id's, so page index is part of the
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[Section(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,
primary_owners=[BasicExpertInfo(display_name=author)],
metadata={
"type": "post",
"author": author,
"time": formatted_time,
},
doc_updated_at=post_date,
)
documents.append(document)
return documents
class XenforoConnector(LoadConnector):
# Class variable to track if the connector has been run before
has_been_run_before = False
def __init__(self, base_url: str) -> None:
self.base_url = base_url
self.initial_run = not XenforoConnector.has_been_run_before
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
self.cookies: dict[str, str] = {}
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
self.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0.0.0 Safari/537.36"
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Xenforo Connector")
return None
def load_from_state(self) -> GenerateDocumentsOutput:
# Standardize URL to always end in /.
if self.base_url[-1] != "/":
self.base_url += "/"
# Remove all extra parameters from the end such as page, post.
matches = ("threads/", "boards/", "forums/")
for each in matches:
if each in self.base_url:
try:
self.base_url = self.base_url[
0 : self.base_url.index(
"/", self.base_url.index(each) + len(each)
)
+ 1
]
except ValueError:
pass
doc_batch: list[Document] = []
all_threads = []
# If the URL contains "boards/" or "forums/", find all threads.
if "boards/" in self.base_url or "forums/" in self.base_url:
pages = get_pages(self.requestsite(self.base_url), self.base_url)
# Get all pages on thread_list_page
for pre_count, thread_list_page in enumerate(pages, start=1):
logger.info(
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
)
all_threads += self.get_threads(thread_list_page)
# If the URL contains "threads/", add the thread to the list.
elif "threads/" in self.base_url:
all_threads.append(self.base_url)
# Process all threads
for thread_count, thread_url in enumerate(all_threads, start=1):
soup = self.requestsite(thread_url)
if soup is None:
logger.error(f"Failed to load page: {self.base_url}")
continue
pages = get_pages(soup, thread_url)
# Getting all pages for all threads
for page_index, page in enumerate(pages, start=1):
logger.info(
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
)
soup_page = self.requestsite(page)
doc_batch.extend(
scrape_page_posts(
soup_page, page_index, thread_url, self.initial_run, self.start
)
)
if doc_batch:
yield doc_batch
# Mark the initial run finished after all threads and pages have been processed
XenforoConnector.has_been_run_before = True
def get_threads(self, url: str) -> list[str]:
soup = self.requestsite(url)
thread_tags = soup.find_all(class_="structItem-title")
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
threads = []
for x in thread_tags:
y = x.find_all(href=True)
for element in y:
link = element["href"]
if "threads/" in link:
stripped = link[0 : link.rfind("/") + 1]
if base_url + stripped not in threads:
threads.append(base_url + stripped)
return threads
def requestsite(self, url: str) -> BeautifulSoup:
try:
response = requests.get(
url, cookies=self.cookies, headers=self.headers, timeout=10
)
if response.status_code != 200:
logger.error(
f"<{url}> Request Error: {response.status_code} - {response.reason}"
)
return BeautifulSoup(response.text, "html.parser")
except TimeoutError:
logger.error("Request timed out.")
except Exception as e:
logger.error(f"Error on {url}")
logger.exception(e)
return BeautifulSoup("", "html.parser")
if __name__ == "__main__":
connector = XenforoConnector(
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -160,7 +160,7 @@ def handle_regular_answer(
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
elif new_message_request.persona_id is not None:
persona = cast(
Persona,
fetch_persona_by_id(

View File

@@ -49,7 +49,7 @@ from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
@@ -522,7 +522,7 @@ if __name__ == "__main__":
# Let the handlers run in the background + re-check for token updates every 60 seconds
Event().wait(timeout=60)
except ConfigNotFoundError:
except KvKeyNotFoundError:
# try again every 30 seconds. This is needed since the user may add tokens
# via the UI at any point in the programs lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens

View File

@@ -2,7 +2,7 @@ import os
from typing import cast
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import SlackBotTokens
@@ -13,7 +13,7 @@ def fetch_tokens() -> SlackBotTokens:
if app_token and bot_token:
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store = get_kv_store()
return SlackBotTokens(
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
)
@@ -22,7 +22,7 @@ def fetch_tokens() -> SlackBotTokens:
def save_tokens(
tokens: SlackBotTokens,
) -> None:
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store = get_kv_store()
dynamic_config_store.store(
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
)

View File

@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()

View File

@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
return list(db_session.execute(doc_ids_stmt).scalars().all())
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
def get_documents_by_ids(
document_ids: list[str],
db_session: Session,
document_ids: list[str],
) -> list[DbDocument]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
documents = db_session.execute(stmt).scalars().all()

View File

@@ -1,10 +1,18 @@
import contextlib
import contextvars
import re
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import ContextManager
import jwt
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
@@ -17,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
@@ -24,27 +33,24 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
@@ -108,6 +114,78 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None
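A tiny illustration of which tenant schema names the guard accepts and rejects; the sample names are made up.

# Illustration only: names that the schema guard accepts and rejects.
assert is_valid_schema_name("tenant_42")          # letters, digits, underscores
assert is_valid_schema_name("_internal")          # leading underscore is allowed
assert not is_valid_schema_name("42_tenant")      # must not start with a digit
assert not is_valid_schema_name('public"; DROP')  # quotes, spaces, and semicolons are rejected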
class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 40,
"max_overflow": 10,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
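A startup sketch for one worker process, assuming SqlEngine is imported from this module (danswer.db.engine). The app name and pool numbers are illustrative only.

from sqlalchemy import text

from danswer.db.engine import SqlEngine

SqlEngine.set_app_name("celery_light_worker")       # hypothetical name; shows up as application_name in pg_stat_activity
SqlEngine.init_engine(pool_size=8, max_overflow=2)  # overrides DEFAULT_ENGINE_KWARGS for this process

engine = SqlEngine.get_engine()                     # later callers reuse the same pooled engine
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))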
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
@@ -120,69 +198,142 @@ def build_connection_string(
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
SqlEngine.set_app_name(app_name)
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _SYNC_ENGINE
return SqlEngine.get_engine()
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
pool_size=5,
max_overflow=0,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
# Context variable to store the current tenant ID
# This allows us to maintain tenant-specific context throughout the request lifecycle
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
# This context variable works in both synchronous and asynchronous contexts
# In async code, it's automatically carried across coroutines
# In sync code, it's managed per thread
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
def get_session() -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
# Dependency to get the current tenant ID and set the context variable
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
token = request.cookies.get("tenant_details")
if not token:
# If no token is present, use the default schema or handle accordingly
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400, detail="Invalid token: tenant_id missing"
)
if not is_valid_schema_name(tenant_id):
raise ValueError("Invalid tenant ID format")
current_tenant_id.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token format")
except ValueError as e:
# Let the 400 error bubble up
raise HTTPException(status_code=400, detail=str(e))
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
session = Session(engine, expire_on_commit=False)
@event.listens_for(session, "after_begin")
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
return session
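A sketch of using get_session_with_tenant from background work, where the tenant id comes from the caller rather than a request JWT. The tenant id is made up.

from sqlalchemy import text

from danswer.db.engine import get_session_with_tenant

# "tenant_abc123" is a made-up tenant id for illustration.
with get_session_with_tenant("tenant_abc123") as session:
    # the after_begin listener sets search_path to the tenant's schema,
    # so unqualified table names resolve inside that schema
    session.execute(text("SELECT 1"))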
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
async def get_async_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session)()
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
@@ -204,10 +355,3 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory

View File

@@ -64,19 +64,12 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProvider:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)

View File

@@ -50,7 +50,7 @@ from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
@@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
class UsageReport(Base):

View File

@@ -11,7 +11,7 @@ from danswer.db.index_attempt import (
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -54,7 +54,7 @@ def check_index_swap(db_session: Session) -> None:
)
if cc_pair_count > 0:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model

View File

@@ -1,3 +1,4 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
@@ -107,12 +108,14 @@ def create_or_add_document_tag_list(
return all_tags
def get_tags_by_value_prefix_for_source_types(
def find_tags(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
sources: list[DocumentSource] | None,
limit: int | None,
db_session: Session,
# if set, both tag_key_prefix and tag_value_prefix must be a match
require_both_to_match: bool = False,
) -> list[Tag]:
query = select(Tag)
@@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types(
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
if tag_value_prefix:
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
query = query.where(or_(*conditions))
final_prefix_condition = (
and_(*conditions) if require_both_to_match else or_(*conditions)
)
query = query.where(final_prefix_condition)
if sources:
query = query.where(Tag.source.in_(sources))
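A caller-side sketch of the renamed find_tags helper; the module path and db_session are assumptions.

from danswer.db.tag import find_tags  # assumed module path for the file shown above

# Only return tags where BOTH the key starts with "department" AND the value
# starts with "eng"; with require_both_to_match=False (the default) either
# prefix alone would be enough to match.
matching_tags = find_tags(
    tag_key_prefix="department",
    tag_value_prefix="eng",
    sources=None,           # or a list of DocumentSource values to restrict by source
    limit=25,
    db_session=db_session,  # assumed to be an open SQLAlchemy Session
    require_both_to_match=True,
)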

View File

@@ -1,3 +1,6 @@
from sqlalchemy.orm import Session
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@@ -13,3 +16,14 @@ def get_default_document_index(
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
)
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
"""
TODO: Use redis to cache this or something
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)

View File

@@ -55,6 +55,21 @@ class DocumentMetadata:
from_ingestion_api: bool = False
@dataclass
class VespaDocumentFields:
"""
Specifies fields in Vespa for a document. Fields set to None will be ignored.
Perhaps we should name this in an implementation agnostic fashion, but it's more
understandable like this for now.
"""
# all other fields except these 4 will always be left alone by the update request
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
@dataclass
class UpdateRequest:
"""
@@ -156,6 +171,16 @@ class Deletable(abc.ABC):
Class must implement the ability to delete document by their unique document ids.
"""
@abc.abstractmethod
def delete_single(self, doc_id: str) -> None:
"""
Given a single document id, hard delete it from the document index
Parameters:
- doc_id: document id as specified by the connector
"""
raise NotImplementedError
@abc.abstractmethod
def delete(self, doc_ids: list[str]) -> None:
"""
@@ -178,11 +203,9 @@ class Updatable(abc.ABC):
"""
@abc.abstractmethod
def update_single(self, update_request: UpdateRequest) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
"""
Updates some set of chunks for a document. The document and fields to update
are specified in the update request. Each update request in the list applies
its changes to a list of document ids.
Updates all chunks for a document with the specified fields.
None values mean that the field does not need an update.
The rationale for a single update function is that it allows retries and parallelism
@@ -190,14 +213,10 @@ class Updatable(abc.ABC):
us to individually handle error conditions per document.
Parameters:
- update_request: for a list of document ids in the update request, apply the same updates
to all of the documents with those ids.
- fields: the fields to update in the document. Any field set to None will not be changed.
Return:
- an HTTPStatus code. The code can be used to decide whether to fail immediately,
retry, etc. Although this method likely hits an HTTP API behind the
scenes, the usage of HTTPStatus is a convenience and the interface is not
actually HTTP specific.
None
"""
raise NotImplementedError
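A sketch of calling the new per-document update interface. Only fields set to non-None are touched; document_index is assumed to be any Updatable implementation already in scope, such as a VespaIndex.

from danswer.document_index.interfaces import VespaDocumentFields

# Only boost and hidden are updated here; access and document_sets stay
# untouched because they are left as None.
fields = VespaDocumentFields(boost=1.5, hidden=False)
document_index.update_single("some-doc-id", fields)  # "some-doc-id" is a placeholder document id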

View File

@@ -1,5 +1,6 @@
import concurrent.futures
import io
import logging
import os
import re
import time
@@ -13,6 +14,7 @@ from typing import cast
import httpx
import requests
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -22,6 +24,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from danswer.document_index.vespa.chunk_retrieval import (
get_all_vespa_ids_for_document_id,
@@ -58,8 +61,8 @@ from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.batching import batch_generator
@@ -68,6 +71,10 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
@dataclass
class _VespaUpdateRequest:
@@ -140,7 +147,7 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
needs_reindexing = False
try:
@@ -377,89 +384,86 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, update_request: UpdateRequest) -> None:
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> None:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
if len(update_request.document_ids) != 1:
raise ValueError("update_request must contain a single document id")
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
update_request.document_ids = [
replace_invalid_doc_id_characters(doc_id)
for doc_id in update_request.document_ids
]
# update_start = time.monotonic()
# Fetch all chunks for each document ahead of time
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
logger.debug(
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
)
normalized_doc_id = replace_invalid_doc_id_characters(doc_id)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
if fields.boost is not None:
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if update_request.access is not None:
if fields.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if fields.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{normalized_doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
# logger.debug(
# "Finished updating Vespa documents in %.2f seconds",
# time.monotonic() - update_start,
# )
total_chunks_updated = 0
while True:
try:
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to update chunks, details: {e.response.text}"
)
raise
resp_data = resp.json()
if "documentCount" in resp_data:
chunks_updated = resp_data["documentCount"]
total_chunks_updated += chunks_updated
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", resp_data["continuation"])
logger.debug(
f"VespaIndex.update_single: "
f"index={index_name} "
f"doc={normalized_doc_id} "
f"chunks_updated={total_chunks_updated}"
)
return
def delete(self, doc_ids: list[str]) -> None:
@@ -478,6 +482,68 @@ class VespaIndex(DocumentIndex):
delete_vespa_docs(
document_ids=doc_ids, index_name=index_name, http_client=http_client
)
return
def delete_single(self, doc_id: str) -> None:
"""Possibly faster overall than the delete method due to using a single
delete call with a selection query."""
# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example
doc_id = replace_invalid_doc_id_characters(doc_id)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
total_chunks_deleted = 0
while True:
try:
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
params=params,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to delete chunk, details: {e.response.text}"
)
raise
resp_data = resp.json()
if "documentCount" in resp_data:
chunks_deleted = resp_data["documentCount"]
total_chunks_deleted += chunks_deleted
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", resp_data["continuation"])
logger.debug(
f"VespaIndex.delete_single: "
f"index={index_name} "
f"doc={doc_id} "
f"chunks_deleted={total_chunks_deleted}"
)
return
def id_based_retrieval(
self,

View File

@@ -1,15 +0,0 @@
from danswer.configs.app_configs import DYNAMIC_CONFIG_STORE
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.store import FileSystemBackedDynamicConfigStore
from danswer.dynamic_configs.store import PostgresBackedDynamicConfigStore
def get_dynamic_config_store() -> DynamicConfigStore:
dynamic_config_store_type = DYNAMIC_CONFIG_STORE
if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__:
raise NotImplementedError("File based config store no longer supported")
if dynamic_config_store_type == PostgresBackedDynamicConfigStore.__name__:
return PostgresBackedDynamicConfigStore()
# TODO: change exception type
raise Exception("Unknown dynamic config store type")

View File

@@ -1,102 +0,0 @@
import json
import os
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import cast
from filelock import FileLock
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.interface import JSON_ro
FILE_LOCK_TIMEOUT = 10
def _get_file_lock(file_name: Path) -> FileLock:
return FileLock(file_name.with_suffix(".lock"))
class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
def __init__(self, dir_path: str) -> None:
# TODO (chris): maybe require all possible keys to be passed in
# at app start somehow to prevent key overlaps
self.dir_path = Path(dir_path)
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
file_path = self.dir_path / key
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(file_path, "w+") as f:
json.dump(val, f)
def load(self, key: str) -> JSON_ro:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(self.dir_path / key) as f:
return cast(JSON_ro, json.load(f))
def delete(self, key: str) -> None:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
os.remove(file_path)
class PostgresBackedDynamicConfigStore(DynamicConfigStore):
@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:
session.close()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# The actual encryption/decryption is done in Postgres, we just need to choose
# which field to set
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()
def load(self, key: str) -> JSON_ro:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise ConfigNotFoundError
if obj.value is not None:
return cast(JSON_ro, obj.value)
if obj.encrypted_value is not None:
return cast(JSON_ro, obj.encrypted_value)
return None
def delete(self, key: str) -> None:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise ConfigNotFoundError
session.commit()

View File

@@ -20,6 +20,8 @@ from pypdf.errors import PdfStreamError
from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -331,9 +333,10 @@ def file_io_to_text(file: IO[Any]) -> str:
def extract_file_text(
file_name: str | None,
file: IO[Any],
file_name: str,
break_on_unprocessable: bool = True,
extension: str | None = None,
) -> str:
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -345,22 +348,29 @@ def extract_file_text(
".html": parse_html_page_basic,
}
def _process_file() -> str:
if file_name:
extension = get_file_ext(file_name)
if check_file_ext_is_valid(extension):
return extension_to_function.get(extension, file_io_to_text)(file)
try:
if get_unstructured_api_key():
return unstructured_to_text(file, file_name)
# Either the file somehow has no name or the extension is not one that we are familiar with
if file_name or extension:
if extension is not None:
final_extension = extension
elif file_name is not None:
final_extension = get_file_ext(file_name)
if check_file_ext_is_valid(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we recognize
if is_text_file(file):
return file_io_to_text(file)
raise ValueError("Unknown file extension and unknown text encoding")
try:
return _process_file()
except Exception as e:
if break_on_unprocessable:
raise RuntimeError(f"Failed to process file: {str(e)}") from e
logger.warning(f"Failed to process file: {str(e)}")
raise RuntimeError(
f"Failed to process file {file_name or 'Unknown'}: {str(e)}"
) from e
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
return ""
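A usage sketch of the reworked extract_file_text signature, where the file handle comes first, file_name is required, and an explicit extension can override detection. The import path and sample file paths are assumptions for illustration.

from danswer.file_processing.extract_file_text import extract_file_text  # assumed module path

# Normal case: the extension is taken from the file name (sample path is made up).
with open("reports/q3_summary.pdf", "rb") as f:
    text = extract_file_text(
        file=f,
        file_name="q3_summary.pdf",
        break_on_unprocessable=False,  # log a warning and return "" instead of raising
    )

# When the name carries no useful extension, it can be passed explicitly.
with open("/tmp/upload_12345", "rb") as f:
    text = extract_file_text(file=f, file_name="upload_12345", extension=".html")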

View File

@@ -0,0 +1,67 @@
from typing import Any
from typing import cast
from typing import IO
from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared
from danswer.configs.constants import KV_UNSTRUCTURED_API_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_unstructured_api_key() -> str | None:
kv_store = get_kv_store()
try:
return cast(str, kv_store.load(KV_UNSTRUCTURED_API_KEY))
except KvKeyNotFoundError:
return None
def update_unstructured_api_key(api_key: str) -> None:
kv_store = get_kv_store()
kv_store.store(KV_UNSTRUCTURED_API_KEY, api_key)
def delete_unstructured_api_key() -> None:
kv_store = get_kv_store()
kv_store.delete(KV_UNSTRUCTURED_API_KEY)
def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> operations.PartitionRequest:
try:
request = operations.PartitionRequest(
partition_parameters=shared.PartitionParameters(
files=shared.Files(content=file.read(), file_name=file_name),
**kwargs,
),
)
return request
except Exception as e:
logger.error(f"Error creating partition request for file {file_name}: {str(e)}")
raise
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="auto")
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
response = unstructured_client.general.partition(req) # type: ignore
elements = dict_to_elements(response.elements)
if response.status_code != 200:
err = f"Received unexpected status code {response.status_code} from Unstructured API."
logger.error(err)
raise ValueError(err)
return "\n\n".join(str(el) for el in elements)

View File

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any usable content, it will not be indexed
return chunks
def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks

View File

@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""
def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)
@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)
@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)

View File

@@ -0,0 +1,41 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
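A wiring sketch for IndexingHeartbeat, mirroring how build_indexing_pipeline passes it to the chunker later in this diff. The index_attempt_id, db_session, work_items, and process names are stand-ins assumed to exist in the caller.

from danswer.indexing.indexing_heartbeat import IndexingHeartbeat

# Every `freq` items of work, the attempt's time_updated column is bumped so the
# indexing job still looks alive while grinding through very long documents.
heartbeat = IndexingHeartbeat(index_attempt_id=index_attempt_id, db_session=db_session, freq=10)

for item in work_items:   # e.g. documents being chunked or chunks being embedded
    process(item)         # stand-in for the real work
    heartbeat.heartbeat()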

View File

@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -220,8 +221,8 @@ def index_doc_batch_prepare(
document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
document_ids=document_ids,
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated_at
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -0,0 +1,7 @@
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.store import PgRedisKVStore
def get_kv_store() -> KeyValueStore:
# this is the only one supported currently
return PgRedisKVStore()

View File

@@ -9,11 +9,11 @@ JSON_ro: TypeAlias = (
)
class ConfigNotFoundError(Exception):
class KvKeyNotFoundError(Exception):
pass
class DynamicConfigStore:
class KeyValueStore:
@abc.abstractmethod
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,99 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
REDIS_KEY_PREFIX = "danswer_kv_store:"
KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
yield session
finally:
session.close()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
try:
self.redis_client.set(
REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION
)
except Exception as e:
# Fallback gracefully to Postgres if Redis fails
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
else:
obj = KVStore(
key=key, value=plain_val, encrypted_value=encrypted_val
) # type: ignore
session.query(KVStore).filter_by(key=key).delete() # just in case
session.add(obj)
session.commit()
def load(self, key: str) -> JSON_ro:
try:
redis_value = self.redis_client.get(REDIS_KEY_PREFIX + key)
if redis_value:
assert isinstance(redis_value, bytes)
return json.loads(redis_value.decode("utf-8"))
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
if obj.value is not None:
value = obj.value
elif obj.encrypted_value is not None:
value = obj.encrypted_value
else:
value = None
try:
self.redis_client.set(REDIS_KEY_PREFIX + key, json.dumps(value))
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
return cast(JSON_ro, value)
def delete(self, key: str) -> None:
try:
self.redis_client.delete(REDIS_KEY_PREFIX + key)
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
session.commit()
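A rough sketch of the read-through/write-through pattern PgRedisKVStore follows above, with plain dictionaries standing in for Redis and Postgres so the cache-miss fallback is visible end to end. The kv_store/kv_load helpers and dict backends are illustrative only, not the real API:

import json
from typing import Any

_redis_cache: dict[str, str] = {}    # best-effort cache; entries may expire
_postgres_rows: dict[str, Any] = {}  # source of truth

def kv_store(key: str, val: Any) -> None:
    # Write the cache first, but never let a cache failure block the write.
    try:
        _redis_cache["danswer_kv_store:" + key] = json.dumps(val)
    except Exception as e:
        print(f"cache write failed for {key!r}: {e}")
    _postgres_rows[key] = val  # Postgres is always written

def kv_load(key: str) -> Any:
    cached = _redis_cache.get("danswer_kv_store:" + key)
    if cached is not None:
        return json.loads(cached)
    if key not in _postgres_rows:
        raise KeyError(key)  # analogous to KvKeyNotFoundError
    value = _postgres_rows[key]
    _redis_cache["danswer_kv_store:" + key] = json.dumps(value)  # re-cache on a miss
    return value

kv_store("reindex", True)
_redis_cache.clear()               # simulate the Redis entry expiring
assert kv_load("reindex") is True  # falls back to "Postgres", then repopulates the cache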

View File

@@ -1,3 +1,4 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
@@ -554,8 +555,7 @@ class Answer:
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
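The point of the itertools.chain change above appears to be that the already-consumed first message is pushed back through the same loop as the rest of the stream, so it gets the same per-item handling (e.g. the StreamStopInfo check) instead of being yielded unconditionally. A small self-contained sketch with a stand-in stop sentinel:

import itertools
from collections.abc import Iterator

def stream_with_first(message: str, stream: Iterator[str]) -> Iterator[str]:
    # Re-attach the already-consumed first message so the single loop below
    # applies the same handling (here: the stop check) to every item.
    for item in itertools.chain([message], stream):
        if item == "<stop>":  # stand-in for a StreamStopInfo sentinel
            return
        yield item

print(list(stream_with_first("hello", iter(["world", "<stop>", "ignored"]))))
# -> ['hello', 'world']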

View File

@@ -290,10 +290,12 @@ class DefaultMultiLLM(LLM):
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
api_key=self._api_key,
base_url=self._api_base,
api_version=self._api_version,
custom_llm_provider=self._custom_llm_provider,
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=prompt,
tools=tools,
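A tiny illustration of the empty-string-to-None coercion used above; the helper name is hypothetical, the behavior of `or None` is standard Python:

def none_if_empty(value: str | None) -> str | None:
    # litellm treats an explicit "" differently from "not provided"; coercing
    # empty strings to None (as `or None` does above) sidesteps that.
    return value or None

print(none_if_empty(""))        # None
print(none_if_empty(None))      # None
print(none_if_empty("sk-abc"))  # sk-abc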

View File

@@ -51,6 +51,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
@@ -64,9 +65,9 @@ from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import IndexingSetting
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -255,7 +256,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
@@ -293,17 +294,17 @@ def translate_saved_search_settings(db_session: Session) -> None:
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except ConfigNotFoundError:
except KvKeyNotFoundError:
logger.notice("No search config found in KV store.")
def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
value = kv_store.load(KV_REINDEX_KEY)
logger.debug(f"Re-indexing flag has value {value}")
return
except ConfigNotFoundError:
except KvKeyNotFoundError:
# Only need to update the flag if it hasn't been set
pass
@@ -368,7 +369,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
# await warm_up_connections()
await warm_up_connections()
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()

View File

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(

View File

@@ -3,23 +3,23 @@ from typing import Optional
import redis
from redis.client import Redis
from redis.connection import ConnectionPool
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
REDIS_POOL_MAX_CONNECTIONS = 10
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: ConnectionPool
_pool: redis.BlockingConnectionPool
def __new__(cls) -> "RedisPool":
if not cls._instance:
@@ -42,33 +42,52 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
) -> redis.BlockingConnectionPool:
"""We use BlockingConnectionPool because it will block and wait for a connection
rather than error if max_connections is reached. This is far more deterministic
behavior and aligned with how we want to use Redis."""
# Using ConnectionPool is not well documented.
# Useful examples: https://github.com/redis/redis-py/issues/780
if ssl:
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
connection_class=redis.SSLConnection,
ssl_ca_certs=ssl_ca_certs,
ssl_cert_reqs=ssl_cert_reqs,
)
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
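For context, a sketch of constructing a BlockingConnectionPool with the same kind of options as above. The host/port values are placeholder assumptions (the real values come from app_configs), and no command is issued so nothing actually connects. The relevant difference from a plain ConnectionPool is that exhausting max_connections makes the blocking pool wait (indefinitely with timeout=None) rather than raise:

import redis

# Pools are created lazily; nothing connects until a command is issued.
blocking_pool = redis.BlockingConnectionPool(
    host="localhost",  # placeholder for illustration only
    port=6379,
    db=0,
    max_connections=10,
    timeout=None,               # wait for a free connection instead of raising
    health_check_interval=30,
    socket_keepalive=True,
)
client = redis.Redis(connection_pool=blocking_pool)
# client.ping()  # would require a reachable Redis server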

View File

@@ -1,8 +1,8 @@
from typing import cast
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger
@@ -17,10 +17,10 @@ def get_kv_search_settings() -> SavedSearchSettings | None:
if the value is updated by another process/instance of the API server. If this reads from an in-memory cache like
Redis then it will be ok. Until then this has some performance implications (though minor)
"""
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
except ConfigNotFoundError:
except KvKeyNotFoundError:
return None
except Exception as e:
logger.error(f"Error loading search settings: {e}")

View File

@@ -1,4 +1,5 @@
import math
from http import HTTPStatus
from fastapi import APIRouter
from fastapi import Depends
@@ -10,6 +11,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -26,17 +29,22 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -190,6 +198,181 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
# look up the last prune task for this connector (if it exists)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if not last_pruning_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No pruning task found.",
)
return CeleryTaskStatus(
id=last_pruning_task.task_id,
name=last_pruning_task.task_name,
status=last_pruning_task.status,
start_time=last_pruning_task.start_time,
register_time=last_pruning_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(
last_pruning_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)
logger.info(f"Pruning the {cc_pair.connector.name} connector.")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
return StatusResponse(
success=True,
message="Successfully created the pruning task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
# look up the last sync task for this connector (if it exists)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if not last_sync_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No sync task found.",
)
return CeleryTaskStatus(
id=last_sync_task.task_id,
name=last_sync_task.task_name,
status=last_sync_task.status,
start_time=last_sync_task.start_time,
register_time=last_sync_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.celery_app import (
sync_external_doc_permissions_task,
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if last_sync_task and check_task_is_live_and_not_timed_out(
last_sync_task, db_session
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
if skip_cc_pair_pruning_by_task(
last_sync_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair_id),
)
return StatusResponse(
success=True,
message="Successfully created the sync task.",
)
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,

View File

@@ -74,8 +74,8 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -116,7 +116,7 @@ def check_google_app_gmail_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_gmail_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
@@ -140,7 +140,7 @@ def delete_google_app_gmail_credentials(
) -> StatusResponse:
try:
delete_google_app_gmail_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -154,7 +154,7 @@ def check_google_app_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_cred().web.client_id}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
@@ -178,7 +178,7 @@ def delete_google_app_credentials(
) -> StatusResponse:
try:
delete_google_app_cred()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -192,7 +192,7 @@ def check_google_service_gmail_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_gmail_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -218,7 +218,7 @@ def delete_google_service_gmail_account_key(
) -> StatusResponse:
try:
delete_gmail_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -232,7 +232,7 @@ def check_google_service_account_key_exist(
) -> dict[str, str]:
try:
return {"service_account_email": get_service_account_key().client_email}
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(
status_code=404, detail="Google Service Account Key not found"
)
@@ -258,7 +258,7 @@ def delete_google_service_account_key(
) -> StatusResponse:
try:
delete_service_account_key()
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
return StatusResponse(
@@ -280,7 +280,7 @@ def upsert_service_account_credential(
DocumentSource.GOOGLE_DRIVE,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
# first delete all existing service account credentials
@@ -306,7 +306,7 @@ def upsert_gmail_service_account_credential(
DocumentSource.GMAIL,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
)
except ConfigNotFoundError as e:
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
# first delete all existing service account credentials
@@ -781,6 +781,7 @@ def connector_run_once(
detail="Connector has no valid credentials, cannot create index attempts.",
)
# Prevents index attempts for cc pairs that already have an index attempt currently running
skipped_credentials = [
credential_id
for credential_id in credential_ids
@@ -790,15 +791,15 @@ def connector_run_once(
credential_id=credential_id,
),
only_current=True,
disinclude_finished=True,
db_session=db_session,
disinclude_finished=True,
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]

View File

@@ -268,6 +268,14 @@ class CCPairFullInfo(BaseModel):
)
class CeleryTaskStatus(BaseModel):
id: str
name: str
status: TaskStatus
start_time: datetime | None
register_time: datetime | None
class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.background.celery.celery_app import celery_app
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
@@ -28,9 +29,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -113,7 +114,7 @@ def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# Only validate every so often
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
curr_time = datetime.now(tz=timezone.utc)
try:
last_check = datetime.fromtimestamp(
@@ -122,7 +123,7 @@ def validate_existing_genai_api_key(
check_freq_sec = timedelta(seconds=GENERATIVE_MODEL_ACCESS_CHECK_FREQ)
if curr_time - last_check < check_freq_sec:
return
except ConfigNotFoundError:
except KvKeyNotFoundError:
# First time checking the key, nothing unusual
pass
@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery_app import (
check_for_connector_deletion_task,
)
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
status=ConnectorCredentialPairStatus.DELETING,
)
# run the beat task to pick up this deletion early
check_for_connector_deletion_task.apply_async(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
celery_app.send_task(
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
)

View File

@@ -10,6 +10,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import fetch_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
@@ -124,17 +125,26 @@ def list_llm_providers(
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
False,
description="True if creating a new provider, False if updating an existing one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
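A minimal sketch of the create-versus-update validation added above, reduced to a pure function so it runs standalone; ConflictError and validate_upsert are illustrative names, not the actual endpoint code:

class ConflictError(Exception):
    pass

def validate_upsert(existing_names: set[str], name: str, is_creation: bool) -> None:
    # Creating a provider whose name already exists is rejected up front rather
    # than silently turning the create into an update.
    if is_creation and name in existing_names:
        raise ConflictError(f"LLM Provider with name {name} already exists")

validate_upsert({"openai"}, "anthropic", is_creation=True)   # ok: new name
validate_upsert({"openai"}, "openai", is_creation=False)     # ok: explicit update
try:
    validate_upsert({"openai"}, "openai", is_creation=True)  # conflict
except ConflictError as e:
    print(e)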

View File

@@ -21,6 +21,9 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.document_index.factory import get_default_document_index
from danswer.file_processing.unstructured import delete_unstructured_api_key
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import update_unstructured_api_key
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.search.models import SearchSettingsCreationRequest
@@ -30,7 +33,6 @@ from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
router = APIRouter(prefix="/search-settings")
logger = setup_logger()
@@ -196,3 +198,27 @@ def update_saved_search_settings(
update_current_search_settings(
search_settings=search_settings, db_session=db_session
)
@router.get("/unstructured-api-key-set")
def unstructured_api_key_set(
_: User | None = Depends(current_admin_user),
) -> bool:
api_key = get_unstructured_api_key()
return api_key is not None
@router.put("/upsert-unstructured-api-key")
def upsert_unstructured_api_key(
unstructured_api_key: str,
_: User | None = Depends(current_admin_user),
) -> None:
update_unstructured_api_key(unstructured_api_key)
@router.delete("/delete-unstructured-api-key")
def delete_unstructured_api_key_endpoint(
_: User | None = Depends(current_admin_user),
) -> None:
delete_unstructured_api_key()

View File

@@ -18,7 +18,7 @@ from danswer.db.slack_bot_config import fetch_slack_bot_configs
from danswer.db.slack_bot_config import insert_slack_bot_config
from danswer.db.slack_bot_config import remove_slack_bot_config
from danswer.db.slack_bot_config import update_slack_bot_config
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.manage.models import SlackBotConfig
from danswer.server.manage.models import SlackBotConfigCreationRequest
from danswer.server.manage.models import SlackBotTokens
@@ -212,5 +212,5 @@ def put_tokens(
def get_tokens(_: User | None = Depends(current_admin_user)) -> SlackBotTokens:
try:
return fetch_tokens()
except ConfigNotFoundError:
except KvKeyNotFoundError:
raise HTTPException(status_code=404, detail="No tokens found")

View File

@@ -38,7 +38,7 @@ from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
@@ -367,7 +367,7 @@ def verify_user_logged_in(
# if auth type is disabled, return a dummy user with preferences from
# the key-value store
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
return fetch_no_auth_user(store)
raise HTTPException(
@@ -405,7 +405,7 @@ def update_user_default_model(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.default_model = request.default_model
set_no_auth_user_preferences(store, no_auth_user.preferences)
@@ -433,7 +433,7 @@ def update_user_assistant_list(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
@@ -487,7 +487,7 @@ def update_user_assistant_visibility(
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
preferences = no_auth_user.preferences
updated_preferences = update_assistant_list(preferences, assistant_id, show)

View File

@@ -588,7 +588,10 @@ def upload_files_for_chat(
# if the file is a doc, extract text and store that so we don't need
# to re-extract it every time we send a message
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file_name=file.filename, file=file.file)
extracted_text = extract_file_text(
file=file.file,
file_name=file.filename or "",
)
text_file_id = str(uuid.uuid4())
file_store.save_file(
file_name=text_file_id,

View File

@@ -18,7 +18,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
from danswer.db.tag import find_tags
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
@@ -99,12 +99,25 @@ def get_tags(
if not allow_prefix:
raise NotImplementedError("Cannot disable prefix match for now")
db_tags = get_tags_by_value_prefix_for_source_types(
tag_key_prefix=match_pattern,
tag_value_prefix=match_pattern,
key_prefix = match_pattern
value_prefix = match_pattern
require_both_to_match = False
# split on = to allow the user to type in "author=bob"
EQUAL_PAT = "="
if match_pattern and EQUAL_PAT in match_pattern:
split_pattern = match_pattern.split(EQUAL_PAT)
key_prefix = split_pattern[0]
value_prefix = EQUAL_PAT.join(split_pattern[1:])
require_both_to_match = True
db_tags = find_tags(
tag_key_prefix=key_prefix,
tag_value_prefix=value_prefix,
sources=sources,
limit=limit,
db_session=db_session,
require_both_to_match=require_both_to_match,
)
server_tags = [
SourceTag(

View File

@@ -19,8 +19,8 @@ from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.settings.models import Notification
from danswer.server.settings.models import Settings
from danswer.server.settings.models import UserSettings
@@ -58,9 +58,9 @@ def fetch_settings(
user_notifications = get_user_notifications(user, db_session)
try:
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except ConfigNotFoundError:
except KvKeyNotFoundError:
needs_reindexing = False
return UserSettings(
@@ -97,7 +97,7 @@ def get_user_notifications(
# Reindexing flag should only be shown to admins, basic users can't trigger it anyway
return []
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
if not needs_index:
@@ -105,7 +105,7 @@ def get_user_notifications(
notif_type=NotificationType.REINDEX, db_session=db_session
)
return []
except ConfigNotFoundError:
except KvKeyNotFoundError:
# If something goes wrong and the flag is gone, better to not start a reindexing
# it's a heavyweight long running job and maybe this flag is cleaned up later
logger.warning("Could not find reindex flag")

View File

@@ -1,16 +1,16 @@
from typing import cast
from danswer.configs.constants import KV_SETTINGS_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.settings.models import Settings
def load_settings() -> Settings:
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store = get_kv_store()
try:
settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY)))
except ConfigNotFoundError:
except KvKeyNotFoundError:
settings = Settings()
dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump())
@@ -18,4 +18,4 @@ def load_settings() -> Settings:
def store_settings(settings: Settings) -> None:
get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.model_dump())
get_kv_store().store(KV_SETTINGS_KEY, settings.model_dump())

View File

@@ -8,7 +8,7 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.custom.base_tool_types import ToolResultType

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM

View File

@@ -10,7 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string

View File

@@ -16,7 +16,7 @@ from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage

View File

@@ -2,7 +2,7 @@ import abc
from collections.abc import Generator
from typing import Any
from danswer.dynamic_configs.interface import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolResponse

View File

@@ -12,8 +12,8 @@ from danswer.configs.constants import KV_CUSTOMER_UUID_KEY
from danswer.configs.constants import KV_INSTANCE_DOMAIN_KEY
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry"
_CACHED_UUID: str | None = None
@@ -34,11 +34,11 @@ def get_or_generate_uuid() -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
except ConfigNotFoundError:
except KvKeyNotFoundError:
_CACHED_UUID = str(uuid.uuid4())
kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True)
@@ -51,11 +51,11 @@ def _get_or_generate_instance_domain() -> str | None:
if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN
kv_store = get_dynamic_config_store()
kv_store = get_kv_store()
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except ConfigNotFoundError:
except KvKeyNotFoundError:
with Session(get_sqlalchemy_engine()) as db_session:
first_user = db_session.query(User).first()
if first_user:

View File

@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.celery_utils import (
should_perform_external_doc_permissions_check,
)
from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
run_permission_sync_entrypoint,
run_external_doc_permission_sync,
)
from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
@@ -26,11 +39,18 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_permissions_task)
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
# Periodic Tasks
#####
@celery_app.task(
name="check_sync_external_permissions_task",
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
def check_sync_external_doc_permissions_task() -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_permissions_check(
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_permissions_task.apply_async(
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@celery_app.task(
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-permissions": {
"task": "check_sync_external_permissions_task",
"schedule": timedelta(seconds=60), # TODO: optimize this
"sync-external-doc-permissions": {
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
"task": "autogenerate_usage_report_task",

View File

@@ -0,0 +1,52 @@
from typing import cast
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
logger = setup_logger()
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
"""This function is likely to move in the worker refactor happening next."""
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning(f"Could not parse usergroup id from {key}")
return
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)
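A rough sketch of the fence/taskset bookkeeping that monitor_usergroup_taskset checks above, using in-memory dicts in place of the Redis keys (all names here are stand-ins): the fence records that a sync was started and how many tasks it enqueued, the taskset holds the ids still outstanding, and once the taskset is empty the group is marked synced and both keys are cleared:

fence: dict[str, int] = {}          # "a sync is in flight" marker + initial task count
taskset: dict[str, set[str]] = {}   # task ids that have not completed yet

def start_sync(group_id: str, task_ids: list[str]) -> None:
    fence[group_id] = len(task_ids)
    taskset[group_id] = set(task_ids)

def complete_task(group_id: str, task_id: str) -> None:
    taskset[group_id].discard(task_id)

def monitor(group_id: str) -> None:
    if group_id not in fence:
        return  # no fence -> nothing to monitor
    remaining = len(taskset[group_id])
    print(f"sync progress: id={group_id} remaining={remaining} initial={fence[group_id]}")
    if remaining > 0:
        return
    # every task reported in: mark the group synced, then clean up both keys
    print(f"Synced usergroup. id='{group_id}'")
    del taskset[group_id]
    del fence[group_id]

start_sync("42", ["t1", "t2"])
monitor("42")              # remaining=2
complete_task("42", "t1")
complete_task("42", "t2")
monitor("42")              # remaining=0 -> synced and cleaned up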

Some files were not shown because too many files have changed in this diff.