Compare commits


52 Commits

Author SHA1 Message Date
Chris Weaver
1f0af86454 Temp patch to remove multiple tool calls (#2720) 2024-10-08 13:22:16 -07:00
hagen-danswer
0e6524dd32 Added quotes to project name to handle reserved words (#2639) 2024-10-08 11:20:44 -07:00
rkuo-danswer
2be133d784 Merge pull request #2716 from danswer-ai/hotfix/v0.7-background-logs
backport: rely on stdout redirection for supervisord logging (#2711)
2024-10-07 15:48:43 -07:00
Richard Kuo (Danswer)
cb668bcff5 backport: rely on stdout redirection for supervisord logging (#2711) 2024-10-07 15:13:43 -07:00
rkuo-danswer
756385e3ac Hotfix/v0.7 harden redis (#2683)
* harden redis

* use blockingconnectionpool
2024-10-04 09:49:32 -07:00
Richard Kuo (Danswer)
1966127bd4 trivy workaround 2024-10-03 15:25:54 -07:00
rkuo-danswer
3ac84da698 Merge pull request #2676 from danswer-ai/hotfix/v0.7-vespa-delete-performance
hotfix for vespa delete performance
2024-10-03 10:59:32 -07:00
rkuo-danswer
7c7f5b37f5 Merge pull request #2675 from danswer-ai/hotfix/v0.7-bump-celery
bump celery
2024-10-03 10:59:14 -07:00
Richard Kuo (Danswer)
0bf9243891 Merge branch 'release/v0.7' of github.com:danswer-ai/danswer into hotfix/v0.7-bump-celery 2024-10-03 10:21:41 -07:00
Richard Kuo (Danswer)
cfe4bbe3c7 Merge branch 'release/v0.7' of github.com:danswer-ai/danswer into hotfix/v0.7-vespa-delete-performance 2024-10-03 10:21:23 -07:00
Richard Kuo (Danswer)
9d18b92b90 fix sync checks 2024-10-03 10:20:57 -07:00
Richard Kuo (Danswer)
74315e21b3 bump celery 2024-10-03 09:44:25 -07:00
Richard Kuo (Danswer)
f9a5b227a1 hotfix for vespa delete performance 2024-10-03 09:43:02 -07:00
Chris Weaver
3e511497d2 Fix overflow of prompt library table (#2606) 2024-09-30 15:31:12 +00:00
hagen-danswer
b0056907fb Added permissions syncing for slack (#2602)
* Added permissions syncing for slack

* add no email case handling

* mypy fixes

* frontend

* minor cleanup

* param tweak
2024-09-30 15:14:43 +00:00
Chris Weaver
728a41a35a Add heartbeat to indexing (#2595) 2024-09-29 19:26:40 -07:00
Chris Weaver
ef8dda2d47 Rely on PVC (#2604) 2024-09-29 17:30:39 -07:00
pablodanswer
15283b3140 prevent nextFormStep unless credential fully set up (#2599) 2024-09-29 22:47:45 +00:00
Chris Weaver
e159b2e947 Fix default assistant (#2600)
* Fix default assistant

* Remove log

* Add newline
2024-09-29 22:47:14 +00:00
Jeff Knapp
9155800fab EKS initial deployment (#2154)
Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-09-29 15:51:31 -07:00
pablodanswer
a392ef0541 Show transition card if no connectors (#2597)
* show transition card if no connectors

* squash

* update apos
2024-09-29 22:35:41 +00:00
Yuhong Sun
5679f0af61 Minor Query History Fix (#2594) 2024-09-29 10:54:08 -07:00
rkuo-danswer
ff8db71cb5 don't write a nightly tag to the same commit more than once (#2585)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-29 10:36:08 -07:00
hagen-danswer
1cff2b82fd Global Curator Fix + Testing (#2591)
* Global Curator Fix

* test fix
2024-09-28 20:14:39 +00:00
Chris Weaver
50dd3c8beb Add size limit to jira tickets (#2586) 2024-09-28 12:49:13 -07:00
hagen-danswer
66a459234d Minor role display refactor (#2578) 2024-09-27 16:50:03 +00:00
rkuo-danswer
19e57474dc Feature/xenforo (#2497)
* Xenforo forum parser support

* clarify ssl cert reqs

* missed a file

* add isLoadState function, fix up xenforo for data driven connector approach

* fixing a new edge case to skip an unexpected parsed element

* change documentsource to xenforo

* make doc id unique and comment what's happening

* remove stray log line

* address code review

---------

Co-authored-by: sime2408 <simun.sunjic@gmail.com>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:36:05 +00:00
rkuo-danswer
f9638f2ea5 try user deploy key approach to tagging (#2575)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 16:04:55 +00:00
rkuo-danswer
fbf51b70d0 Feature/celery multi (#2470)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* multiple celery workers

* update logs as well and set prefetch multipliers appropriate to the worker intent

* add db refresh to connector deletion

* add some preliminary locking

* organize tasks into separate files

* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this.

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

* add multi workers to dev_run_background_jobs.py

* update supervisord with some recommended settings for celery

* name celery workers and shorten dev script prefixing

* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)

* fix comments

* autoscale sqlalchemy pool size to celery concurrency (allow override later?)

* supervisord needs the percent symbols escaped

* use name as primary check, some minor refactoring and type hinting too.

* addressing code review

* fix import

* fix prune_documents_task references

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-27 00:50:55 +00:00
hagen-danswer
b97cc01bb2 Added confluence permission syncing (#2537)
* Added confluence permission syncing

* separated out group and doc syncing

* minor bugfix and mypy

* added frontend and fixed bug

* Minor refactor

* dealt with confluence rate limits!

* mypy fixes!!!

* addressed yuhong feedback

* primary key fix
2024-09-26 22:10:41 +00:00
rkuo-danswer
6d48fd5d99 clamp retry to max_delay (#2570) 2024-09-26 21:56:46 +00:00
Chris Weaver
1f61447b4b Add open in new tab for custom links (#2568) 2024-09-26 20:01:35 +00:00
rkuo-danswer
deee2b3513 push to docker latest when git tag contains "latest", and tag nightly (#2564)
* comment docker tag latest

* make latest builds contingent on a "latest" keyword in the tag

* v4 checkout

* nightly tag push
2024-09-26 17:40:13 +00:00
hagen-danswer
b73d66c84a Cleaned up foreign key cleanup for user group deletion (#2559)
* cleaned up fk cleanup for user group deletion

* added test for user group deletion
2024-09-26 03:38:01 +00:00
rkuo-danswer
c5a61f4820 Feature/test pruning (#2556)
* add test to exercise pruning

* add prettierignore

* mypy fix

* mypy again

* try getting all the env vars set up correctly

* fix ports and hostnames
2024-09-25 23:34:13 +00:00
pablodanswer
ea4a3cbf86 update folder list (#2563) 2024-09-25 16:25:45 -07:00
rkuo-danswer
166514cedf ssl_ca_certs should default to None, not "". (#2560)
* ssl_ca_certs should default to None, not "".

otherwise, if ssl is enabled it will look for the cert on an empty path and fail.

* mypy fix
2024-09-25 19:56:21 +00:00
pablodanswer
be50ae1e71 flex none (#2558) 2024-09-25 10:19:37 -07:00
pablodanswer
f89504ec53 Update some ux edge cases (#2545)
* update some ux edge cases

* update some formatting / ports
2024-09-25 16:46:43 +00:00
trial-danswer
6b3213b1e4 fix typo (#2543)
* fix typo

* Update EmbeddingFormPage.tsx

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-09-25 01:25:46 +00:00
Chris Weaver
48577bf0e4 Allow = in tag filter (#2548)
* Allow = in tag filter

* Rename func
2024-09-24 21:37:35 +00:00
pablodanswer
c59d1ff0a5 Update merge queue logic (#2554)
* update merge queue logic

* remove space
2024-09-24 18:45:05 +00:00
pablodanswer
ba38dec592 ensure default_assistant passed through 2024-09-24 11:35:19 -07:00
pablodanswer
f5adc3063e Update theming (#2552)
* update theming

* update

* update theming
2024-09-24 18:01:08 +00:00
hagen-danswer
8cfe80c53a Added doc_set__user_group cleanup for user_group deletion (#2551) 2024-09-24 16:09:52 +00:00
ThomaciousD
487250320b fix saml email login upsert issue 2024-09-24 07:42:08 -07:00
rkuo-danswer
c8d13922a9 rename classes and ignore deprecation warnings we mostly don't have c… (#2546)
* rename classes and ignore deprecation warnings we mostly don't have control over

* copy pytest.ini

* ignore CryptographyDeprecationWarning

* fully qualify the warning
2024-09-24 00:21:42 +00:00
rkuo-danswer
cb75449cec Feature/runs on 2 (#2547)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)

* try upping the RAM for future flake proofing
2024-09-23 23:46:20 +00:00
rkuo-danswer
b66514cd21 test self hosted runner (#2541)
* test self hosted runner

* update more docker builds with self hosted runner

* convert everything to runs-on (except web container)
2024-09-23 21:57:23 +00:00
Chris Weaver
77650c9ee3 Fix misc tool call errors (#2544)
* Fix misc tool call errors

* Fix middleware
2024-09-23 21:00:48 +00:00
pablodanswer
316b6b99ea Tooling testing (#2533)
* add initial testing

* add custom tool testing

* update ports

* update tests - additional coverage

* update types
2024-09-23 20:09:01 +00:00
Chris Weaver
34c2aa0860 Support svg navigation items (#2542)
* Support SVG nav items

* Handle specifying custom SVGs for navbar

* Add comment

* More comment

* More comment
2024-09-23 13:22:20 -07:00
237 changed files with 36881 additions and 2224 deletions


@@ -7,16 +7,17 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
-# TODO: make this a matrix build like the web containers
-runs-on:
-  group: amd64-image-builders
+# TODO: investigate a matrix build like the web container
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
-uses: actions/checkout@v2
+uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,7 +32,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
@@ -41,12 +42,20 @@ jobs:
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
-${{ env.REGISTRY_IMAGE }}:latest
+${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}


@@ -5,14 +5,18 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
-runs-on:
-  group: amd64-image-builders
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
-uses: actions/checkout@v2
+uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -31,13 +35,21 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
-danswer/danswer-model-server:${{ github.ref_name }}
-danswer/danswer-model-server:latest
+${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
+${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'


@@ -7,7 +7,8 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -35,7 +36,7 @@ jobs:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
-type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
+type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -112,8 +113,16 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'


@@ -1,3 +1,6 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".
name: Tag Latest Version
on:
@@ -9,7 +12,9 @@ on:
jobs:
tag:
-runs-on: ubuntu-latest
+# See https://runs-on.com/runners/linux/
+# use a lower powered instance since this just does i/o to docker hub
+runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1


@@ -12,7 +12,8 @@ on:
jobs:
lint-test:
-runs-on: Amd64
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:


@@ -3,11 +3,14 @@ name: Python Checks
on:
merge_group:
pull_request:
-branches: [ main ]
+branches:
+  - main
+  - 'release/**'
jobs:
mypy-check:
-runs-on: ubuntu-latest
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code


@@ -15,10 +15,14 @@ env:
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:
-runs-on: ubuntu-latest
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend


@@ -3,11 +3,14 @@ name: Python Unit Tests
on:
merge_group:
pull_request:
-branches: [ main ]
+branches:
+  - main
+  - 'release/**'
jobs:
backend-check:
-runs-on: ubuntu-latest
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend


@@ -1,6 +1,6 @@
name: Quality Checks PR
concurrency:
-group: Quality-Checks-PR-${{ github.head_ref }}
+group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
@@ -9,7 +9,8 @@ on:
jobs:
quality-checks:
-runs-on: ubuntu-latest
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- uses: actions/checkout@v4
with:


@@ -1,19 +1,22 @@
name: Run Integration Tests
concurrency:
-group: Run-Integration-Tests-${{ github.head_ref }}
+group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
-branches: [ main ]
+branches:
+  - main
+  - 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
integration-tests:
-runs-on: Amd64
+# See https://runs-on.com/runners/linux/
+runs-on: [runs-on,runner=8cpu-linux-x64,ram=32,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -120,6 +123,7 @@ jobs:
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
@@ -128,6 +132,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/integration-test-runner:it
continue-on-error: true
id: run_tests

.github/workflows/tag-nightly.yml (new file)

@@ -0,0 +1,54 @@
name: Nightly Tag Push
on:
schedule:
- cron: '0 0 * * *' # Runs every day at midnight UTC
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Check for existing nightly tag
id: check_tag
run: |
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
echo "A tag starting with 'nightly-latest' already exists on HEAD."
echo "tag_exists=true" >> $GITHUB_OUTPUT
else
echo "No tag starting with 'nightly-latest' exists on HEAD."
echo "tag_exists=false" >> $GITHUB_OUTPUT
fi
# don't tag again if HEAD already has a nightly-latest tag on it
- name: Create Nightly Tag
if: steps.check_tag.outputs.tag_exists == 'false'
env:
DATE: ${{ github.run_id }}
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
echo "Creating tag: $TAG_NAME"
git tag $TAG_NAME
- name: Push Tag
if: steps.check_tag.outputs.tag_exists == 'false'
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME
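For reference, the tag-name construction above and the shape that the "nightly-latest*" existence-check glob depends on can be sketched outside the workflow (plain shell, no repository needed):

```shell
# Recompute the nightly tag name exactly as the "Create Nightly Tag" step does
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"

# The existence check relies on the glob "nightly-latest*" matching this name,
# with an 8-digit date suffix, so the same tag is never pushed twice in one day
echo "$TAG_NAME" | grep -Eq '^nightly-latest-[0-9]{8}$' && echo "format: ok"
```

Because the date suffix (not the run id) forms the tag, two runs on the same HEAD in one day produce the same name, which is exactly what the points-at check guards against.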

.prettierignore (new file)

@@ -0,0 +1 @@
backend/tests/integration/tests/pruning/website

View File

@@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)

File diff suppressed because it is too large.


@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy.orm import Session
@@ -69,6 +70,30 @@ def get_deletion_attempt_snapshot(
)
def skip_cc_pair_pruning_by_task(
pruning_task: TaskQueueState | None, db_session: Session
) -> bool:
"""task should be the latest prune task for this cc_pair"""
if not ALLOW_SIMULTANEOUS_PRUNING:
# if only one prune is allowed at any time, then check to see if any prune
# is active
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return True
if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
# if the last task is live right now, we shouldn't start a new one
return True
return False
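The two gates in `skip_cc_pair_pruning_by_task` reduce to a small decision table; a minimal sketch with plain booleans standing in for the `TaskQueueState`/db lookups (not the project's actual imports):

```python
def skip_pruning(allow_simultaneous: bool,
                 any_prune_live: bool,
                 this_pair_live: bool) -> bool:
    # Gate 1: if only one prune may run at a time, any live prune task
    # anywhere in the system blocks a new one.
    if not allow_simultaneous and any_prune_live:
        return True
    # Gate 2: a live prune task for this cc_pair always blocks a new one.
    return this_pair_live

print(skip_pruning(False, True, False))  # True  (global gate trips)
print(skip_pruning(True, True, False))   # False (simultaneous pruning allowed)
print(skip_pruning(True, False, True))   # True  (this pair is still pruning)
```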
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
@@ -79,31 +104,26 @@ def should_prune_cc_pair(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
+if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
+    return False
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
# If the connector has never been pruned, then compare vs when the connector
# was created
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
-if not ALLOW_SIMULTANEOUS_PRUNING:
-    pruning_type_task_name = name_cc_prune_task()
-    last_pruning_type_task = get_latest_task_by_type(
-        pruning_type_task_name, db_session
-    )
-    if last_pruning_type_task and check_task_is_live_and_not_timed_out(
-        last_pruning_type_task, db_session
-    ):
-        return False
-if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
-    return False
if not last_pruning_task.start_time:
# if the last prune task hasn't started, we shouldn't start a new one
return False
# if the last prune task has a start time, then compare against it to determine
# if we should start
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
@@ -141,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
"""Checks to see if we're listening to the named queue"""
# how to get a list of queues this worker is listening to
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
queue_names = list(worker.app.amqp.queues.consume_from.keys())
for queue_name in queue_names:
if queue_name == name:
return True
return False
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken, but the way we do it is to
check the hostname set for the celery worker, either in celeryconfig.py or on the
command line."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("heavy"):
return False
return True
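The hostname convention that `celery_is_worker_primary` relies on can be restated as a standalone sketch (a re-implementation for illustration, not the project import): workers named with a `light` or `heavy` prefix, whether set in celeryconfig.py or on the command line, are secondary; everything else is primary.

```python
def is_worker_primary(hostname: str) -> bool:
    # Secondary workers are identified purely by name prefix.
    return not (hostname.startswith("light") or hostname.startswith("heavy"))

print(is_worker_primary("primary@localhost"))  # True
print(is_worker_primary("light@localhost"))    # False
print(is_worker_primary("heavy@localhost"))    # False
```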


@@ -1,7 +1,9 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
@@ -9,6 +11,7 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
CELERY_SEPARATOR = ":"
@@ -36,12 +39,30 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
"retry_on_timeout": True,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True


@@ -0,0 +1,133 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_pool import RedisPool
redis_pool = RedisPool()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
search_settings = get_current_search_settings(db_session)
last_indexing = get_last_attempt(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
search_settings_id=search_settings.id,
db_session=db_session,
)
if last_indexing:
if (
last_indexing.status == IndexingStatus.IN_PROGRESS
or last_indexing.status == IndexingStatus.NOT_STARTED
):
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated


@@ -0,0 +1,140 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from celery.utils.log import get_task_logger
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True
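The cleanup task above guards itself with `pg_try_advisory_lock` so only one worker runs it at a time. As a sketch of that idea, the helper below derives a stable signed 64-bit lock id from a name; the helper is an illustrative addition (the real code takes its ids from the `PostgresAdvisoryLocks` enum rather than hashing):

```python
import hashlib


def advisory_lock_id(name: str) -> int:
    """Map a lock name to a signed 64-bit integer, the type Postgres expects."""
    digest = hashlib.sha256(name.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


# The non-blocking acquire/release pair used in the task above:
TRY_LOCK_SQL = "SELECT pg_try_advisory_lock(:id)"
UNLOCK_SQL = "SELECT pg_advisory_unlock(:id)"
```

`pg_try_advisory_lock` returns false immediately when another session holds the lock, which is why the task can simply `return 0` instead of blocking.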


@@ -0,0 +1,120 @@
from celery import shared_task
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""Connector pruning task. For a cc_pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"cc_pair not found for connector_id={connector_id} credential_id={credential_id}"
)
return
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
task_logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
task_logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
raise e
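The heart of the pruning step above is a plain set difference: anything indexed locally but absent from the source's freshly pulled id list should be deleted. A minimal sketch (function name is illustrative):

```python
def compute_prune_ids(indexed_ids, source_ids):
    """Return the locally indexed ids missing from the source, sorted for
    deterministic batching."""
    return sorted(set(indexed_ids) - set(source_ids))
```

Note the asymmetry: ids present at the source but not yet indexed are ignored here; picking those up is the job of a normal indexing run, not pruning.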


@@ -0,0 +1,526 @@
import traceback
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
redis_pool = RedisPool()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# fetch_versioned_implementation raises ModuleNotFoundError on the MIT version, which is expected
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
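The non-overlap guard used by this beat task can be sketched without a live Redis. `SimpleLock` below is an in-memory stand-in for `redis.lock.Lock` (a process-wide dict plays the role of the shared Redis key), showing the acquire-nonblocking-or-bail shape:

```python
class SimpleLock:
    _held = {}  # lock name -> owner token; stand-in for the shared Redis key

    def __init__(self, name, token):
        self.name, self.token = name, token

    def acquire(self, blocking=True):
        # sketch only handles the non-blocking case the beat task uses
        if self.name in SimpleLock._held:
            return False  # another instance holds it; caller should bail
        SimpleLock._held[self.name] = self.token
        return True

    def owned(self):
        return SimpleLock._held.get(self.name) == self.token

    def release(self):
        if self.owned():
            del SimpleLock._held[self.name]


def beat_task(lock):
    """Mirror the try/finally shape above: run only if we win the lock."""
    if not lock.acquire(blocking=False):
        return "skipped"
    try:
        return "ran"  # the real task does its db/redis work here
    finally:
        if lock.owned():
            lock.release()
```

The `if lock.owned()` check before release matters: a lock whose timeout expired mid-task may already belong to someone else, and releasing it blindly would steal theirs.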
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The fence value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(
f"Stale document sync progress: remaining={count} initial={initial_count}"
)
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The fence value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(f"could not parse cc_pair id from {fence_key}")
return
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The fence value is not an integer.")
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return
try:
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
f"and credential_id: '{cc_pair.credential_id}'. "
f"Deleted {initial_count} docs."
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task's lock timeout is CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
with Session(get_sqlalchemy_engine()) as db_session:
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
@shared_task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
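The retry delay used above is exponential backoff starting at 2^4 seconds, matching the inline comment: retry 0 waits 16s, retry 1 waits 32s, retry 2 waits 64s, after which `max_retries=3` gives up. As a pure function:

```python
def retry_countdown(retries: int) -> int:
    """Seconds to wait before the next retry: 2^(retries + 4)."""
    return 2 ** (retries + 4)
```

Celery's `self.retry(exc=e, countdown=...)` re-raises internally, so the `return True` at the end is only reached on success or on a swallowed soft-time-limit hit.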


@@ -10,15 +10,27 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.access.access import get_access_for_documents
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import get_document_connector_counts
from danswer.db.document import mark_document_as_synced
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -26,6 +38,9 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
_DELETION_BATCH_SIZE = 1000
@@ -108,3 +123,89 @@ def delete_connector_credential_pair_batch(
),
)
db_session.commit()
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete_single(doc_id=document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(update_request=update_request)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
# update_docs_last_modified__no_commit(
# db_session=db_session,
# document_ids=[document_id],
# )
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True


@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)
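The heartbeat wired into the embedder above reports liveness only every `freq` batches so the progress write stays cheap. A minimal sketch of that throttling, with a callback standing in for the db update the real `IndexingHeartbeat` performs (class and parameter names here are illustrative):

```python
class BatchHeartbeat:
    """Invoke on_beat once per `freq` calls to heartbeat()."""

    def __init__(self, freq, on_beat):
        self.freq = freq
        self.count = 0
        self.on_beat = on_beat

    def heartbeat(self):
        self.count += 1
        if self.count % self.freq == 0:
            self.on_beat()  # real code would touch the index attempt row here
```

With `freq=10`, an indexing run that processes 25 batches triggers exactly two liveness updates instead of 25.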


@@ -416,6 +416,7 @@ def update_loop(
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -444,6 +445,7 @@ def update_loop(
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")


@@ -164,13 +164,29 @@ REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# will propagate to both our redis client as well as celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting to None may help when there is a proxy in the way closing idle connections
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
CELERY_BROKER_POOL_LIMIT = int(
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
)
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
#####
# Connector Configs
#####
@@ -247,6 +263,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -270,7 +290,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)


@@ -1,3 +1,5 @@
import platform
import socket
from enum import auto
from enum import Enum
@@ -34,7 +36,9 @@ POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -62,6 +66,7 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
class DocumentSource(str, Enum):
@@ -104,6 +109,7 @@ class DocumentSource(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
@@ -186,6 +192,7 @@ class DanswerCeleryQueues:
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
@@ -198,3 +205,13 @@ class DanswerCeleryPriority(int, Enum):
MEDIUM = auto()
LOW = auto()
LOWEST = auto()
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
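The keepalive map above is platform-aware because macOS exposes `TCP_KEEPALIVE` where Linux uses `TCP_KEEPIDLE`. A self-contained sketch of the same construction (the function wrapper is an illustration, not how the constants are defined in the module):

```python
import platform
import socket

def build_keepalive_options() -> dict[int, int]:
    """Probe every 15s, give up after 3 failed probes, and wait 60s of
    idle time before the first probe. TCP_KEEPALIVE on Darwin plays the
    role TCP_KEEPIDLE plays elsewhere."""
    opts = {
        socket.TCP_KEEPINTVL: 15,
        socket.TCP_KEEPCNT: 3,
    }
    if platform.system() == "Darwin":
        opts[socket.TCP_KEEPALIVE] = 60
    else:
        opts[socket.TCP_KEEPIDLE] = 60
    return opts
```

A dict shaped like this is what redis-py accepts as `socket_keepalive_options` alongside `socket_keepalive=True`.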

View File

@@ -0,0 +1,32 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
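The new utility module keeps id construction trivial: base URL plus relative content URL. A stdlib-only sketch of both helpers — the regex attachment scan is an illustrative assumption (the real connector parses with BeautifulSoup, which tolerates malformed markup):

```python
import re

def build_confluence_document_id(base_url: str, content_url: str) -> str:
    # The document id is simply the base url concatenated with the
    # relative page or attachment-download url.
    return f"{base_url}{content_url}"

def get_used_attachments_re(text: str) -> list[str]:
    """Regex-based sketch of the bs4 parser above; assumes well-formed
    ri:attachment tags."""
    return re.findall(r'<ri:attachment[^>]*\bri:filename="([^"]+)"', text)
```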

View File

@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.confluence_utils import get_used_attachments
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filename currently in used
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -624,13 +610,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
# The url and the id are the same
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
)
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
@@ -683,8 +672,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter and not time_filter(last_updated):
continue
attachment_url = self._attachment_to_download_link(
self.confluence_client, attachment
# The url and the id are the same
attachment_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"]
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment

View File

@@ -50,6 +50,12 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
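The hunk above caps how long the rate-limit handler will sleep, regardless of what the server's Retry-After header claims. Isolated as a pure function (a sketch; the real code does this inline):

```python
import logging

logger = logging.getLogger(__name__)

def clamp_retry_after(retry_after: int, max_delay: int = 600) -> int:
    """Never wait longer than `max_delay` seconds between retries,
    even if the server requests a longer backoff."""
    if retry_after > max_delay:
        logger.warning(
            f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
        )
        retry_after = max_delay
    return retry_after
```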

View File

@@ -9,6 +9,7 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{description}\n" + "\n".join(
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {self.jira_project}",
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
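Two small behaviors from the Jira hunks above, sketched as standalone functions (names are illustrative): the ticket-size gate measures the UTF-8 byte length rather than the character count, since multi-byte characters inflate storage, and the project name is quoted so JQL reserved words don't break the query.

```python
JIRA_CONNECTOR_MAX_TICKET_SIZE = 100 * 1024  # 100KB default, as above

def ticket_too_large(
    ticket_content: str, max_bytes: int = JIRA_CONNECTOR_MAX_TICKET_SIZE
) -> bool:
    """True when the ticket body (description + comments) exceeds the
    byte budget and should be skipped."""
    return len(ticket_content.encode("utf-8")) > max_bytes

def project_jql(project: str) -> str:
    """Quote the project name so a project literally named, say, "order"
    still parses as a value rather than a JQL keyword."""
    return f'project = "{project}"'
```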

View File

@@ -42,6 +42,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.xenforo.connector import XenforoConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
@@ -62,6 +63,7 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -97,6 +99,7 @@ def identify_connector_class(
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -8,13 +8,12 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
basic_retry_wrapper = retry_builder()
def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
def _get_channels(
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])
@@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# try getting private channels as well at first
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels
def get_channel_messages(
@@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -266,7 +254,7 @@ def filter_channels(
]
def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
@@ -328,7 +316,44 @@ def get_all_docs(
)
class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel
This is nearly identical to _get_all_docs, but it returns a set of ids instead of documents
This makes it an order of magnitude faster than get_all_docs
"""
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
all_doc_ids = set()
for channel in filtered_channels:
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we don't have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
return all_doc_ids
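The speedup comes from the id scheme: because the id is derived from the channel id and the thread's first-message timestamp, both already present in the channel history, no per-thread API call is needed. A sketch of the format (helper name is hypothetical):

```python
def slack_doc_id(channel_id: str, thread_ts: str) -> str:
    """Document id used above: channel id and the ts of the thread's
    first message, joined by a double underscore."""
    return f"{channel_id}__{thread_ts}"
```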
class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
@@ -349,6 +374,16 @@ class SlackPollConnector(PollConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_source_ids(self) -> set[str]:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
@@ -356,7 +391,7 @@ class SlackPollConnector(PollConnector):
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,

View File

@@ -10,11 +10,13 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
@@ -34,7 +36,7 @@ def get_message_link(
)
def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
return logged_call
def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)
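The exported wrappers nest three decorators, innermost first: logging, then rate limiting, then retries. The ordering generalizes to any wrapper stack (the `compose` helper below is hypothetical; the real module nests the calls explicitly):

```python
from collections.abc import Callable
from typing import Any

def compose(*wrappers: Callable) -> Callable:
    """Apply wrappers outermost-first, mirroring the nesting above:
    compose(retry, rate_limit, log)(call) == retry(rate_limit(log(call)))."""
    def apply(call: Callable[..., Any]) -> Callable[..., Any]:
        for w in reversed(wrappers):
            call = w(call)
        return call
    return apply
```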
def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,

View File

@@ -0,0 +1,244 @@
"""
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
forum, followed by the board name. For example:
base_url = 'https://www.example.com/forum/boards/some-topic/'
The `load_from_state` method is used to load documents from the forum. It takes an optional `state` parameter, which
can be used to specify a state from which to start loading documents.
"""
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
import pytz
import requests
from bs4 import BeautifulSoup
from bs4 import Tag
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_title(soup: BeautifulSoup) -> str:
el = soup.find("h1", "p-title-value")
if not el:
return ""
title = el.text
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
title = title.replace(char, "_")
return title
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
page_tags = soup.select("li.pageNav-page")
page_numbers = []
for button in page_tags:
if re.match(r"^\d+$", button.text):
page_numbers.append(button.text)
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
all_pages = []
for x in range(1, int(max_pages) + 1):
all_pages.append(f"{url}page-{x}")
return all_pages
def parse_post_date(post_element: BeautifulSoup) -> datetime:
el = post_element.find("time")
if not isinstance(el, Tag) or "datetime" not in el.attrs:
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
date_value = el["datetime"]
# Ensure date_value is a string (if it's a list, take the first element)
if isinstance(date_value, list):
date_value = date_value[0]
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
return datetime_to_utc(post_date)
def scrape_page_posts(
soup: BeautifulSoup,
page_index: int,
url: str,
initial_run: bool,
start_time: datetime,
) -> list:
title = get_title(soup)
documents = []
for post in soup.find_all("div", class_="message-inner"):
post_date = parse_post_date(post)
if initial_run or post_date > start_time:
el = post.find("div", class_="bbWrapper")
if not el:
continue
post_text = el.get_text(strip=True) + "\n"
author_tag = post.find("a", class_="username")
if author_tag is None:
author_tag = post.find("span", class_="username")
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
# TODO: if a caller calls this for each page of a thread, it may see the
# same post multiple times if there is a sticky post
# that appears on each page of a thread.
# it's important to generate unique doc id's, so page index is part of the
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[Section(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,
primary_owners=[BasicExpertInfo(display_name=author)],
metadata={
"type": "post",
"author": author,
"time": formatted_time,
},
doc_updated_at=post_date,
)
documents.append(document)
return documents
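As the TODO above notes, sticky posts repeat on every page of a thread, so the page index is folded into the document id to keep ids unique. A sketch of that scheme, assuming `DocumentSource.XENFORO.value == "xenforo"` as defined earlier in this diff:

```python
def xenforo_doc_id(title: str, page_index: int, formatted_time: str) -> str:
    """Page index is part of the id so the same sticky post seen on
    pages 1 and 2 of a thread still yields distinct document ids."""
    return f"xenforo_{title}_{page_index}_{formatted_time}"
```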
class XenforoConnector(LoadConnector):
# Class variable to track if the connector has been run before
has_been_run_before = False
def __init__(self, base_url: str) -> None:
self.base_url = base_url
self.initial_run = not XenforoConnector.has_been_run_before
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
self.cookies: dict[str, str] = {}
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
self.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0.0.0 Safari/537.36"
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Xenforo Connector")
return None
def load_from_state(self) -> GenerateDocumentsOutput:
# Standardize URL to always end in /.
if self.base_url[-1] != "/":
self.base_url += "/"
# Remove all extra parameters from the end such as page, post.
matches = ("threads/", "boards/", "forums/")
for each in matches:
if each in self.base_url:
try:
self.base_url = self.base_url[
0 : self.base_url.index(
"/", self.base_url.index(each) + len(each)
)
+ 1
]
except ValueError:
pass
doc_batch: list[Document] = []
all_threads = []
# If the URL contains "boards/" or "forums/", find all threads.
if "boards/" in self.base_url or "forums/" in self.base_url:
pages = get_pages(self.requestsite(self.base_url), self.base_url)
# Get all pages on thread_list_page
for pre_count, thread_list_page in enumerate(pages, start=1):
logger.info(
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
)
all_threads += self.get_threads(thread_list_page)
# If the URL contains "threads/", add the thread to the list.
elif "threads/" in self.base_url:
all_threads.append(self.base_url)
# Process all threads
for thread_count, thread_url in enumerate(all_threads, start=1):
soup = self.requestsite(thread_url)
if soup is None:
logger.error(f"Failed to load page: {self.base_url}")
continue
pages = get_pages(soup, thread_url)
# Getting all pages for all threads
for page_index, page in enumerate(pages, start=1):
logger.info(
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
)
soup_page = self.requestsite(page)
doc_batch.extend(
scrape_page_posts(
soup_page, page_index, thread_url, self.initial_run, self.start
)
)
if doc_batch:
yield doc_batch
# Mark the initial run finished after all threads and pages have been processed
XenforoConnector.has_been_run_before = True
def get_threads(self, url: str) -> list[str]:
soup = self.requestsite(url)
thread_tags = soup.find_all(class_="structItem-title")
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
threads = []
for x in thread_tags:
y = x.find_all(href=True)
for element in y:
link = element["href"]
if "threads/" in link:
stripped = link[0 : link.rfind("/") + 1]
if base_url + stripped not in threads:
threads.append(base_url + stripped)
return threads
def requestsite(self, url: str) -> BeautifulSoup:
try:
response = requests.get(
url, cookies=self.cookies, headers=self.headers, timeout=10
)
if response.status_code != 200:
logger.error(
f"<{url}> Request Error: {response.status_code} - {response.reason}"
)
return BeautifulSoup(response.text, "html.parser")
except requests.exceptions.Timeout:
logger.error(f"Request to {url} timed out.")
except Exception as e:
logger.error(f"Error on {url}")
logger.exception(e)
return BeautifulSoup("", "html.parser")
if __name__ == "__main__":
connector = XenforoConnector(
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()

View File

@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
return list(db_session.execute(doc_ids_stmt).scalars().all())
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
def get_documents_by_ids(
document_ids: list[str],
db_session: Session,
document_ids: list[str],
) -> list[DbDocument]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
documents = db_session.execute(stmt).scalars().all()

View File

@@ -1,8 +1,10 @@
import contextlib
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import ContextManager
from sqlalchemy import event
@@ -32,14 +34,9 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
@@ -108,6 +105,67 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
class SqlEngine:
"""Class to manage a global sql alchemy engine (needed for proper resource control)
Will eventually subsume most of the standalone functions in this file.
Sync only for now"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 40,
"max_overflow": 10,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different celery workers and tasks
need different settings."""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the sql alchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!"""
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
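`SqlEngine.get_engine` uses double-checked locking: a lock-free fast path, then a re-check under the lock so only one thread ever constructs the engine. The pattern in isolation (a minimal sketch, with `object()` standing in for engine creation):

```python
import threading

class Singleton:
    """Minimal sketch of the locking used by SqlEngine above."""
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:          # fast path, no lock
            with cls._lock:
                if cls._instance is None:  # re-check under the lock
                    cls._instance = object()
        return cls._instance
```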
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
@@ -125,24 +183,11 @@ def build_connection_string(
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
SqlEngine.set_app_name(app_name)
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _SYNC_ENGINE
return SqlEngine.get_engine()
def get_sqlalchemy_async_engine() -> AsyncEngine:
@@ -154,7 +199,9 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
pool_size=40,
max_overflow=10,

View File

@@ -64,19 +64,12 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProvider:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)

View File

@@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
class UsageReport(Base):

View File

@@ -1,3 +1,4 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
@@ -107,12 +108,14 @@ def create_or_add_document_tag_list(
return all_tags
def get_tags_by_value_prefix_for_source_types(
def find_tags(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
sources: list[DocumentSource] | None,
limit: int | None,
db_session: Session,
# if set, both tag_key_prefix and tag_value_prefix must be a match
require_both_to_match: bool = False,
) -> list[Tag]:
query = select(Tag)
@@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types(
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
if tag_value_prefix:
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
query = query.where(or_(*conditions))
final_prefix_condition = (
and_(*conditions) if require_both_to_match else or_(*conditions)
)
query = query.where(final_prefix_condition)
if sources:
query = query.where(Tag.source.in_(sources))

View File

@@ -1,3 +1,6 @@
from sqlalchemy.orm import Session
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@@ -13,3 +16,14 @@ def get_default_document_index(
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
)
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
"""
TODO: Use redis to cache this or something
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)

View File

@@ -156,6 +156,16 @@ class Deletable(abc.ABC):
Class must implement the ability to delete documents by their unique document ids.
"""
@abc.abstractmethod
def delete_single(self, doc_id: str) -> None:
"""
Given a single document id, hard delete it from the document index
Parameters:
- doc_id: document id as specified by the connector
"""
raise NotImplementedError
@abc.abstractmethod
def delete(self, doc_ids: list[str]) -> None:
"""

View File

@@ -13,6 +13,7 @@ from typing import cast
import httpx
import requests
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -479,6 +480,66 @@ class VespaIndex(DocumentIndex):
document_ids=doc_ids, index_name=index_name, http_client=http_client
)
def delete_single(self, doc_id: str) -> None:
"""Possibly faster overall than the delete method due to using a single
delete call with a selection query."""
# Vespa deletion is poorly documented ... luckily we found this
# https://docs.vespa.ai/en/operations/batch-delete.html#example
doc_id = replace_invalid_doc_id_characters(doc_id)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with httpx.Client(http2=True) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
"selection": f"{index_name}.document_id=='{doc_id}'",
"cluster": DOCUMENT_INDEX_NAME,
}
)
total_chunks_deleted = 0
while True:
try:
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
params=params,
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(
f"Failed to delete chunk, details: {e.response.text}"
)
raise
resp_data = resp.json()
if "documentCount" in resp_data:
chunks_deleted = resp_data["documentCount"]
total_chunks_deleted += chunks_deleted
# Check for continuation token to handle pagination
if "continuation" not in resp_data:
break # Exit loop if no continuation token
if not resp_data["continuation"]:
break # Exit loop if continuation token is empty
params = params.set("continuation", resp_data["continuation"])
logger.debug(
f"VespaIndex.delete_single: "
f"index={index_name} "
f"doc={doc_id} "
f"chunks_deleted={total_chunks_deleted}"
)
def id_based_retrieval(
self,
chunk_requests: list[VespaChunkRequest],

View File

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any useable content, it will not be indexed
return chunks
def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks

View File

@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""
def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)
@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)
@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)

View File

@@ -0,0 +1,41 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")

View File
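`IndexingHeartbeat` above is a throttled progress marker: every call increments a counter, but the expensive work (touching `time_updated` and committing) only happens every `freq`-th call. The pattern in isolation, with a plain callback standing in for the DB commit (names here are illustrative):

```python
from typing import Callable


class CountingHeartbeat:
    """Count every heartbeat call; run the (expensive) update only on
    every freq-th call, mirroring IndexingHeartbeat's throttling."""

    def __init__(self, on_update: Callable[[], None], freq: int = 10):
        self.cnt = 0
        self.freq = freq
        self.on_update = on_update  # stands in for the db_session commit

    def heartbeat(self) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            self.on_update()
```

With `freq=1`, as in the chunker wiring later in this changeset, every call updates; larger values trade staleness for fewer commits.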

@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -220,8 +221,8 @@ def index_doc_batch_prepare(
document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
document_ids=document_ids,
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -1,3 +1,4 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
@@ -315,7 +316,9 @@ class Answer:
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
# as of now, we don't support multiple tool calls in sequence, which is why
# we don't need to pass this in here
# tools=[tool.tool_definition() for tool in self.tools],
)
return
@@ -554,8 +557,7 @@ class Answer:
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
for item in itertools.chain([message], stream):
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return

View File
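The `_stream` fix above addresses a subtle bug: the first item had already been pulled off the stream and yielded separately, so the `StreamStopInfo` check never saw it. `itertools.chain([message], stream)` puts the consumed item back in front of the iterator so every element goes through the same check. The pattern in miniature (the function and names are illustrative):

```python
import itertools


def collect_until_stop(first, rest, stop=None):
    """Re-prepend an already-consumed first item so it passes through
    the same stop check as the remainder of the stream."""
    out = []
    for item in itertools.chain([first], rest):
        if item == stop:
            break
        out.append(item)
    return out
```

Note that without the chain, a stop marker arriving as the very first item would slip through unchecked.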

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(

View File

@@ -3,23 +3,23 @@ from typing import Optional
import redis
from redis.client import Redis
from redis.connection import ConnectionPool
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
REDIS_POOL_MAX_CONNECTIONS = 10
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: ConnectionPool
_pool: redis.BlockingConnectionPool
def __new__(cls) -> "RedisPool":
if not cls._instance:
@@ -42,30 +42,42 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
) -> redis.BlockingConnectionPool:
"""We use BlockingConnectionPool because it will block and wait for a connection
rather than error if max_connections is reached. This is far more deterministic
behavior and aligned with how we want to use Redis."""
# Using ConnectionPool is not well documented.
# Useful examples: https://github.com/redis/redis-py/issues/780
if ssl:
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
connection_class=redis.SSLConnection,
ssl_ca_certs=ssl_ca_certs,
ssl_cert_reqs=ssl_cert_reqs,
)
return redis.ConnectionPool(
return redis.BlockingConnectionPool(
host=host,
port=port,
db=db,
password=password,
max_connections=max_connections,
timeout=None,
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
socket_keepalive=True,
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
)

View File

@@ -1,4 +1,5 @@
import math
from http import HTTPStatus
from fastapi import APIRouter
from fastapi import Depends
@@ -10,6 +11,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -26,7 +29,9 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import get_latest_task
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPairPruningTask
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
@@ -36,7 +41,6 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -190,6 +194,92 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CCPairPruningTask:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
# look up the last prune task for this connector (if it exists)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if not last_pruning_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No pruning task found.",
)
return CCPairPruningTask(
id=last_pruning_task.task_id,
name=last_pruning_task.task_name,
status=last_pruning_task.status,
start_time=last_pruning_task.start_time,
register_time=last_pruning_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(
last_pruning_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)
logger.info(f"Pruning the {cc_pair.connector.name} connector.")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
return StatusResponse(
success=True,
message="Successfully created the pruning task.",
)
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,

View File

@@ -268,6 +268,14 @@ class CCPairFullInfo(BaseModel):
)
class CCPairPruningTask(BaseModel):
id: str
name: str
status: TaskStatus
start_time: datetime | None
register_time: datetime | None
class FailedConnectorIndexingStatus(BaseModel):
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.background.celery.celery_app import celery_app
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery_app import (
check_for_connector_deletion_task,
)
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
status=ConnectorCredentialPairStatus.DELETING,
)
# run the beat task to pick up this deletion early
check_for_connector_deletion_task.apply_async(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
celery_app.send_task(
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
)

View File

@@ -10,6 +10,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import fetch_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
@@ -124,17 +125,26 @@ def list_llm_providers(
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
False,
description="True if creating a new provider, False if updating an existing one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")

View File

@@ -18,7 +18,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
from danswer.db.tag import find_tags
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
@@ -99,12 +99,25 @@ def get_tags(
if not allow_prefix:
raise NotImplementedError("Cannot disable prefix match for now")
db_tags = get_tags_by_value_prefix_for_source_types(
tag_key_prefix=match_pattern,
tag_value_prefix=match_pattern,
key_prefix = match_pattern
value_prefix = match_pattern
require_both_to_match = False
# split on = to allow the user to type in "author=bob"
EQUAL_PAT = "="
if match_pattern and EQUAL_PAT in match_pattern:
split_pattern = match_pattern.split(EQUAL_PAT)
key_prefix = split_pattern[0]
value_prefix = EQUAL_PAT.join(split_pattern[1:])
require_both_to_match = True
db_tags = find_tags(
tag_key_prefix=key_prefix,
tag_value_prefix=value_prefix,
sources=sources,
limit=limit,
db_session=db_session,
require_both_to_match=require_both_to_match,
)
server_tags = [
SourceTag(

View File

@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.celery_utils import (
should_perform_external_doc_permissions_check,
)
from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
run_permission_sync_entrypoint,
run_external_doc_permission_sync,
)
from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
@@ -26,11 +39,18 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_permissions_task)
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
# Periodic Tasks
#####
@celery_app.task(
name="check_sync_external_permissions_task",
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
def check_sync_external_doc_permissions_task() -> None:
"""Runs periodically to sync external document permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_permissions_check(
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_permissions_task.apply_async(
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@celery_app.task(
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-permissions": {
"task": "check_sync_external_permissions_task",
"schedule": timedelta(seconds=60), # TODO: optimize this
"sync-external-doc-permissions": {
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
"task": "autogenerate_usage_report_task",

View File

@@ -0,0 +1,52 @@
from typing import cast
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
logger = setup_logger()
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
"""This function is likely to move in the worker refactor happening next."""
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning(f"Could not parse usergroup id from {key}")
return
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)

View File

@@ -1,21 +1,17 @@
from typing import cast
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
logger = setup_logger()
@@ -38,13 +34,13 @@ def should_perform_chat_ttl_check(
return True
def should_perform_external_permissions_check(
def should_perform_external_doc_permissions_check(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
latest_task = get_latest_task(task_name, db_session)
if not latest_task:
@@ -57,41 +53,20 @@ def should_perform_external_permissions_check(
return True
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None:
"""This function is likely to move in the worker refactor happening next."""
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning("Could not parse usergroup id from {key}")
return
def should_perform_external_group_permissions_check(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:
return
task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error(f"The fence value is not an integer: {fence_value}")
return
latest_task = get_latest_task(task_name, db_session)
if not latest_task:
return True
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
if check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
with Session(get_sqlalchemy_engine()) as db_session:
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)
return True


@@ -2,5 +2,9 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_permissions_task__{cc_pair_id}"
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_group_permissions_task__{cc_pair_id}"
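Splitting the old combined task name into separate doc and group builders gives each cc_pair two stable keys, so the scheduler can look up the latest run of each sync independently. A minimal sketch (the `latest_runs` dict below is hypothetical, for illustration only):

```python
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
    return f"sync_external_doc_permissions_task__{cc_pair_id}"

def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
    return f"sync_external_group_permissions_task__{cc_pair_id}"

# A scheduler can use the names as stable keys into a task-run registry
latest_runs = {"sync_external_doc_permissions_task__42": "2024-10-01T00:00:00Z"}
task_name = name_sync_external_doc_permissions_task(cc_pair_id=42)
assert latest_runs.get(task_name) is not None
```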


@@ -16,7 +16,9 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import DocumentSet__UserGroup
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import Persona__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -32,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_persona__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def validate_user_creation_permissions(
db_session: Session,
user: User | None,
@@ -62,8 +151,12 @@ def validate_user_creation_permissions(
status_code=400,
detail=detail,
)
user_curated_groups = fetch_user_groups_for_user(
db_session=db_session, user_id=user.id, only_curator_groups=True
db_session=db_session,
user_id=user.id,
# Global curators can curate all groups they are member of
only_curator_groups=user.role != UserRole.GLOBAL_CURATOR,
)
user_curated_group_ids = set([group.id for group in user_curated_groups])
target_group_ids_set = set(target_group_ids)
@@ -285,42 +378,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
return db_user_group
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -475,21 +532,6 @@ def update_user_group(
return db_user_group
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@@ -498,16 +540,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_credential__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
_cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_token_rate_limit__user_group_relationships__no_commit(
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_persona__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
outdated_only=False,
)
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -516,20 +573,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session.commit()
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
"""
This assumes that all the fk cleanup has already been done.
"""
db_session.delete(user_group)
db_session.commit()
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
@@ -541,26 +590,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
db_session.commit()
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group.id,
outdated_only=False,
)
# need to flush so that we don't get a foreign key error when deleting the user group row
db_session.flush()
db_session.delete(user_group)
db_session.commit()
def delete_user_group_cc_pair_relationship__no_commit(
cc_pair_id: int, db_session: Session
) -> None:


@@ -0,0 +1,18 @@
from typing import Any
from atlassian import Confluence # type:ignore
def build_confluence_client(
connector_specific_config: dict[str, Any], raw_credentials_json: dict[str, Any]
) -> Confluence:
is_cloud = connector_specific_config.get("is_cloud", False)
return Confluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=connector_specific_config["wiki_base"].rstrip("/"),
# passing in username causes issues for Confluence data center
username=raw_credentials_json["confluence_username"] if is_cloud else None,
password=raw_credentials_json["confluence_access_token"] if is_cloud else None,
token=raw_credentials_json["confluence_access_token"] if not is_cloud else None,
)
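The interesting part of `build_confluence_client` is the credential split: Cloud authenticates with username plus API token, while Data Center wants only a personal access token (passing a username causes issues there). A hedged sketch of that selection logic without the `atlassian` dependency:

```python
from typing import Any

def _confluence_client_kwargs(
    connector_specific_config: dict[str, Any],
    raw_credentials_json: dict[str, Any],
) -> dict[str, Any]:
    """Illustrative only: mirrors the kwarg selection in build_confluence_client."""
    is_cloud = connector_specific_config.get("is_cloud", False)
    token = raw_credentials_json["confluence_access_token"]
    return {
        "api_version": "cloud" if is_cloud else "latest",
        # Trailing slashes break URL joining downstream, so strip them
        "url": connector_specific_config["wiki_base"].rstrip("/"),
        "username": raw_credentials_json["confluence_username"] if is_cloud else None,
        "password": token if is_cloud else None,
        "token": token if not is_cloud else None,
    }

cloud = _confluence_client_kwargs(
    {"is_cloud": True, "wiki_base": "https://x.atlassian.net/wiki/"},
    {"confluence_username": "me@x.io", "confluence_access_token": "tok"},
)
dc = _confluence_client_kwargs(
    {"wiki_base": "https://wiki.corp"},
    {"confluence_access_token": "tok"},
)
assert cloud["username"] == "me@x.io" and cloud["token"] is None
assert dc["token"] == "tok" and dc["username"] is None
```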


@@ -1,19 +1,254 @@
from typing import Any
from atlassian import Confluence # type:ignore
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
build_confluence_client,
)
logger = setup_logger()
_REQUEST_PAGINATION_LIMIT = 100
def _get_space_permissions(
db_session: Session,
confluence_client: Confluence,
space_id: str,
) -> ExternalAccess:
get_space_permissions = make_confluence_call_handle_rate_limit(
confluence_client.get_space_permissions
)
space_permissions = get_space_permissions(space_id).get("permissions", [])
user_emails = set()
# Confluence enforces that group names are unique
group_names = set()
is_externally_public = False
for permission in space_permissions:
subs = permission.get("subjects")
if subs:
# If there are subjects, then there are explicit users or groups with access
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
user_emails.add(email)
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
group_names.add(group_name)
else:
# If there are no subjects, then the permission is for everyone
if permission.get("operation", {}).get(
"operation"
) == "read" and permission.get("anonymousAccess", False):
# If the permission specifies read access for anonymous users, then
# the space is publicly accessible
is_externally_public = True
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(user_emails)
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_externally_public,
)
def _get_restrictions_for_page(
db_session: Session,
page: dict[str, Any],
space_permissions: ExternalAccess,
) -> ExternalAccess:
"""
WARNING: This function does no pagination. If a page is restricted within
the space and more than 200 users or more than 200 groups have explicit
read access, some of them will be left out. 200 is a large number, so this
is unlikely, but be aware of the limitation.
"""
restrictions_json = page.get("restrictions", {})
read_access_dict = restrictions_json.get("read", {}).get("restrictions", {})
read_access_user_jsons = read_access_dict.get("user", {}).get("results", [])
read_access_group_jsons = read_access_dict.get("group", {}).get("results", [])
is_space_public = read_access_user_jsons == [] and read_access_group_jsons == []
if not is_space_public:
read_access_user_emails = [
user["email"] for user in read_access_user_jsons if user.get("email")
]
read_access_groups = [group["name"] for group in read_access_group_jsons]
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(read_access_user_emails)
)
external_access = ExternalAccess(
external_user_emails=set(read_access_user_emails),
external_user_group_ids=set(read_access_groups),
is_public=False,
)
else:
external_access = space_permissions
return external_access
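The fallback above is worth isolating: a page with no explicit read restrictions inherits the space-level access, while a restricted page gets only its own users and groups. A minimal sketch, with plain dicts standing in for `ExternalAccess`:

```python
def restrictions_for_page(page: dict, space_access: dict) -> dict:
    """Illustrative: mirrors the inherit-or-restrict decision in _get_restrictions_for_page."""
    read = page.get("restrictions", {}).get("read", {}).get("restrictions", {})
    users = [u["email"] for u in read.get("user", {}).get("results", []) if u.get("email")]
    groups = [g["name"] for g in read.get("group", {}).get("results", [])]
    if not users and not groups:
        # No explicit restrictions: the page inherits the space permissions
        return space_access
    return {"users": set(users), "groups": set(groups), "public": False}

space = {"users": {"a@x.io"}, "groups": set(), "public": False}
open_page = {"restrictions": {}}
locked_page = {"restrictions": {"read": {"restrictions": {
    "user": {"results": [{"email": "b@x.io"}]}, "group": {"results": []}}}}}
assert restrictions_for_page(open_page, space) is space
assert restrictions_for_page(locked_page, space)["users"] == {"b@x.io"}
```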
def _fetch_attachment_document_ids_for_page_paginated(
confluence_client: Confluence, page: dict[str, Any]
) -> list[str]:
"""
Starts by extracting the first batch of attachments embedded in the page.
If all attachments fit in that first batch, this function makes no API
calls at all.
"""
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
attachment_doc_ids = []
attachments_dict = page["children"]["attachment"]
start = 0
while True:
attachments_list = attachments_dict["results"]
attachment_doc_ids.extend(
[
build_confluence_document_id(
base_url=confluence_client.url,
content_url=attachment["_links"]["download"],
)
for attachment in attachments_list
]
)
if "next" not in attachments_dict["_links"]:
break
start += len(attachments_list)
attachments_dict = get_attachments_from_content(
page_id=page["id"],
start=start,
limit=_REQUEST_PAGINATION_LIMIT,
)
return attachment_doc_ids
def _fetch_all_pages_paginated(
confluence_client: Confluence,
space_id: str,
) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
# For each page, this fetches the page's attachments and restrictions.
expansion_strings = [
"children.attachment",
"restrictions.read.restrictions.user",
"restrictions.read.restrictions.group",
]
expansion_string = ",".join(expansion_strings)
all_pages = []
start = 0
while True:
pages_dict = get_all_pages_from_space(
space=space_id,
start=start,
limit=_REQUEST_PAGINATION_LIMIT,
expand=expansion_string,
)
all_pages.extend(pages_dict)
response_size = len(pages_dict)
if response_size < _REQUEST_PAGINATION_LIMIT:
break
start += response_size
return all_pages
def _fetch_all_page_restrictions_for_space(
db_session: Session,
confluence_client: Confluence,
space_id: str,
space_permissions: ExternalAccess,
) -> dict[str, ExternalAccess]:
all_pages = _fetch_all_pages_paginated(
confluence_client=confluence_client,
space_id=space_id,
)
document_restrictions: dict[str, ExternalAccess] = {}
for page in all_pages:
# This assigns the same permissions to all attachments of a page and
# to the page itself, because the attachments are stored in the same
# Confluence space as the page.
# WARNING: we create a db Document entry for every attachment, even though
# attachments may not be standalone documents. This is likely fine because
# we only upsert a document with its permissions.
attachment_document_ids = [
build_confluence_document_id(
base_url=confluence_client.url,
content_url=page["_links"]["webui"],
)
]
attachment_document_ids.extend(
_fetch_attachment_document_ids_for_page_paginated(
confluence_client=confluence_client, page=page
)
)
page_permissions = _get_restrictions_for_page(
db_session=db_session,
page=page,
space_permissions=space_permissions,
)
for attachment_document_id in attachment_document_ids:
document_restrictions[attachment_document_id] = page_permissions
return document_restrictions
def confluence_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented ACL sync for confluence, no-op")
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
it in postgres so that when it gets created later, the permissions are
already populated
"""
confluence_client = build_confluence_client(
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
)
space_permissions = _get_space_permissions(
db_session=db_session,
confluence_client=confluence_client,
space_id=cc_pair.connector.connector_specific_config["space"],
)
fresh_doc_permissions = _fetch_all_page_restrictions_for_space(
db_session=db_session,
confluence_client=confluence_client,
space_id=cc_pair.connector.connector_specific_config["space"],
space_permissions=space_permissions,
)
for doc_id, ext_access in fresh_doc_permissions.items():
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)


@@ -1,19 +1,107 @@
from typing import Any
from collections.abc import Iterator
from atlassian import Confluence # type:ignore
from requests import HTTPError
from sqlalchemy.orm import Session
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
build_confluence_client,
)
logger = setup_logger()
_PAGE_SIZE = 100
def _get_confluence_group_names_paginated(
confluence_client: Confluence,
) -> Iterator[str]:
get_all_groups = make_confluence_call_handle_rate_limit(
confluence_client.get_all_groups
)
start = 0
while True:
try:
groups = get_all_groups(start=start, limit=_PAGE_SIZE)
except HTTPError as e:
if e.response.status_code in (403, 404):
return
raise e
for group in groups:
if group_name := group.get("name"):
yield group_name
if len(groups) < _PAGE_SIZE:
break
start += _PAGE_SIZE
def _get_group_members_email_paginated(
confluence_client: Confluence,
group_name: str,
) -> list[str]:
get_group_members = make_confluence_call_handle_rate_limit(
confluence_client.get_group_members
)
group_member_emails: list[str] = []
start = 0
while True:
try:
members = get_group_members(
group_name=group_name, start=start, limit=_PAGE_SIZE
)
except HTTPError as e:
if e.response.status_code in (403, 404):
return group_member_emails
raise e
group_member_emails.extend(
[member.get("email") for member in members if member.get("email")]
)
if len(members) < _PAGE_SIZE:
break
start += _PAGE_SIZE
return group_member_emails
def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented group sync for confluence, no-op")
confluence_client = build_confluence_client(
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
)
danswer_groups: list[ExternalUserGroup] = []
# Confluence enforces that group names are unique
for group_name in _get_confluence_group_names_paginated(confluence_client):
group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)


@@ -1,4 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
@@ -8,15 +10,17 @@ from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
@@ -27,6 +31,42 @@ add_retries = retry_builder(tries=5, delay=5, max_delay=30)
logger = setup_logger()
def _get_docs_with_additional_info(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> dict[str, Any]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0.0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = {
doc.id: doc.additional_info
for doc_batch in doc_batch_generator
for doc in doc_batch
}
return docs_with_additional_info
def _fetch_permissions_paginated(
drive_service: Any, drive_file_id: str
) -> Iterator[dict[str, Any]]:
@@ -122,8 +162,6 @@ def _fetch_google_permissions_for_document_id(
def gdrive_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
"""
Adds the external permissions to the documents in postgres
@@ -131,10 +169,24 @@ def gdrive_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
for doc in docs_with_additional_info:
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
# Here we run the connector to grab all the ids
# this may grab ids before they are indexed but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run
docs_with_additional_info = _get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
)
for doc_id, doc_additional_info in docs_with_additional_info.items():
ext_access = _fetch_google_permissions_for_document_id(
db_session=db_session,
drive_file_id=doc.additional_info,
drive_file_id=doc_additional_info,
raw_credentials_json=cc_pair.credential.credential_json,
company_google_domains=[
cast(dict[str, str], sync_details)["company_domain"]
@@ -142,7 +194,7 @@ def gdrive_doc_sync(
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc.id,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)


@@ -17,7 +17,6 @@ from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
logger = setup_logger()
@@ -105,9 +104,12 @@ def _fetch_group_members_paginated(
def gdrive_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
google_drive_creds, _ = get_google_drive_creds(
cc_pair.credential.credential_json,
scopes=FETCH_GROUPS_SCOPES,


@@ -5,31 +5,79 @@ from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_function_map import (
DOC_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
FULL_FETCH_PERIOD_IN_SECONDS,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
GROUP_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_utils import (
get_docs_with_additional_info,
)
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
logger = setup_logger()
def run_permission_sync_entrypoint(
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If PERMISSION_SYNC_PERIODS has no period for this source, we always run the sync.
if not source_sync_period:
return True
# If the last sync is None, it has never been run so we run the sync
if cc_pair.last_time_perm_sync is None:
return True
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
# If the last sync is greater than the full fetch period, we run the sync
if (current_time - last_sync).total_seconds() > source_sync_period:
return True
return False
def run_external_group_permission_sync(
db_session: Session,
cc_pair_id: int,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
source_type = cc_pair.connector.source
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if group_sync_func is None:
# Not all sync connectors support group permissions so this is fine
return
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
logger.debug(f"Syncing groups for {source_type}")
if group_sync_func is not None:
group_sync_func(
db_session,
cc_pair,
)
# update postgres
db_session.commit()
except Exception as e:
logger.error(f"Error updating document index: {e}")
db_session.rollback()
def run_external_doc_permission_sync(
db_session: Session,
cc_pair_id: int,
) -> None:
# TODO: separate out group and doc sync
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
@@ -37,90 +85,57 @@ def run_permission_sync_entrypoint(
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No permission sync function found for source type: {source_type}"
)
sync_details = cc_pair.auto_sync_options
if sync_details is None:
raise ValueError(f"No auto sync options found for source type: {source_type}")
# If the source type is not polling, we only fetch the permissions every
# _FULL_FETCH_PERIOD_IN_SECONDS seconds
full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
if full_fetch_period is not None:
last_sync = cc_pair.last_time_perm_sync
if (
last_sync
and (
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
).total_seconds()
< full_fetch_period
):
return
# Here we run the connector to grab all the ids
# this may grab ids before they are indexed but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run
docs_with_additional_info = get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
)
# This function updates:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
logger.debug(f"Syncing groups for {source_type}")
if group_sync_func is not None:
group_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# This function updates:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
logger.debug(f"Syncing docs for {source_type}")
doc_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# This function fetches the updated access for the documents
# and returns a dictionary of document_ids and access
# This is the access we want to update vespa with
docs_access = get_access_for_documents(
document_ids=[doc.id for doc in docs_with_additional_info],
db_session=db_session,
)
# Then we build the update requests to update vespa
update_reqs = [
UpdateRequest(document_ids=[doc_id], access=doc_access)
for doc_id, doc_access in docs_access.items()
]
# Don't bother syncing the secondary index; it will be synced after the switchover anyway
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
logger.debug(f"Syncing docs for {source_type}")
doc_sync_func(
db_session,
cc_pair,
)
# Get the document ids for the cc pair
document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# This function fetches the updated access for the documents
# and returns a dictionary of document_ids and access
# This is the access we want to update vespa with
docs_access = get_access_for_documents(
document_ids=document_ids_for_cc_pair,
db_session=db_session,
)
# Then we build the update requests to update vespa
update_reqs = [
UpdateRequest(document_ids=[doc_id], access=doc_access)
for doc_id, doc_access in docs_access.items()
]
# Don't bother syncing the secondary index; it will be synced after the switchover anyway
document_index = get_current_primary_default_document_index(db_session)
# update vespa
document_index.update(update_reqs)
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
# update postgres
db_session.commit()
except Exception as e:
logger.error(f"Error updating document index: {e}")
logger.error(f"Error Syncing Permissions: {e}")
db_session.rollback()


@@ -1,56 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
logger = setup_logger()
class DocsWithAdditionalInfo(BaseModel):
id: str
additional_info: Any
def get_docs_with_additional_info(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocsWithAdditionalInfo]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = [
DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
for doc_batch in doc_batch_generator
for doc in doc_batch
]
logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
return docs_with_additional_info


@@ -0,0 +1,192 @@
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.models import InputType
from danswer.connectors.slack.connector import get_channels
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
logger = setup_logger()
def _extract_channel_id_from_doc_id(doc_id: str) -> str:
"""
Extracts the channel ID from a document ID string.
The document ID is expected to be in the format: "{channel_id}__{message_ts}"
Args:
doc_id (str): The document ID string.
Returns:
str: The extracted channel ID.
Raises:
ValueError: If the doc_id doesn't contain the expected separator.
"""
try:
channel_id, _ = doc_id.split("__", 1)
return channel_id
except ValueError:
raise ValueError(f"Invalid doc_id format: {doc_id}")
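The parsing behavior above can be sketched as a standalone helper (the function name and the sample doc_id below are illustrative, not from the codebase):

```python
def extract_channel_id(doc_id: str) -> str:
    # Mirrors _extract_channel_id_from_doc_id: everything before the
    # first "__" separator is treated as the Slack channel id.
    channel_id, sep, _ = doc_id.partition("__")
    if not sep:
        raise ValueError(f"Invalid doc_id format: {doc_id}")
    return channel_id


print(extract_channel_id("C0123456789__1696300000.000100"))  # -> C0123456789
```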
def _get_slack_document_ids_and_channels(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.PRUNE,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, IdConnector)
channel_doc_map: dict[str, list[str]] = {}
for doc_id in runnable_connector.retrieve_all_source_ids():
channel_id = _extract_channel_id_from_doc_id(doc_id)
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_id)
return channel_doc_map
def _fetch_worspace_permissions(
db_session: Session,
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
user_emails = set()
for email in user_id_to_email_map.values():
user_emails.add(email)
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
return ExternalAccess(
external_user_emails=user_emails,
# No group<->document mapping for slack
external_user_group_ids=set(),
# No way to determine if the Slack workspace is invite-only without an Enterprise license
is_public=False,
)
def _fetch_channel_permissions(
db_session: Session,
slack_client: WebClient,
workspace_permissions: ExternalAccess,
user_id_to_email_map: dict[str, str],
) -> dict[str, ExternalAccess]:
channel_permissions = {}
public_channels = get_channels(
client=slack_client,
get_public=True,
get_private=False,
)
public_channel_ids = [
channel["id"] for channel in public_channels if "id" in channel
]
for channel_id in public_channel_ids:
channel_permissions[channel_id] = workspace_permissions
private_channels = get_channels(
client=slack_client,
get_public=False,
get_private=True,
)
private_channel_ids = [
channel["id"] for channel in private_channels if "id" in channel
]
for channel_id in private_channel_ids:
# Collect all member ids for the channel pagination calls
member_ids = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.conversations_members,
channel=channel_id,
):
member_ids.extend(result.get("members", []))
# Collect all member emails for the channel
member_emails = set()
for member_id in member_ids:
member_email = user_id_to_email_map.get(member_id)
if not member_email:
# If the user is an external user, they won't get returned from the
# conversations_members call, so we need to make a separate call to users_info
# and add them to the user_id_to_email_map
member_info = slack_client.users_info(user=member_id)
member_email = member_info["user"]["profile"].get("email")
if not member_email:
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
member_emails.add(member_email)
channel_permissions[channel_id] = ExternalAccess(
external_user_emails=member_emails,
# No group<->document mapping for slack
external_user_group_ids=set(),
# No way to determine if the Slack workspace is invite-only without an Enterprise license
is_public=False,
)
return channel_permissions
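The access-resolution rule above — public channels inherit the workspace-wide access while private channels are restricted to their resolved member emails — can be sketched in isolation (the `Access` dataclass and helper name are simplified stand-ins for `ExternalAccess` and the real function):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Access:
    user_emails: frozenset[str]
    is_public: bool = False


def resolve_channel_access(
    workspace_access: Access,
    public_channel_ids: list[str],
    private_channel_members: dict[str, set[str]],
) -> dict[str, Access]:
    # Public channels reuse the workspace-wide access object, while
    # private channels get an access set built from their members only.
    perms: dict[str, Access] = {}
    for channel_id in public_channel_ids:
        perms[channel_id] = workspace_access
    for channel_id, emails in private_channel_members.items():
        perms[channel_id] = Access(user_emails=frozenset(emails))
    return perms
```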
def slack_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
"""
Adds the external permissions to the documents in postgres.
If a document doesn't already exist in postgres, we create
it so that when it is indexed later, its permissions are
already populated.
"""
slack_client = WebClient(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
db_session=db_session,
cc_pair=cc_pair,
)
workspace_permissions = _fetch_worspace_permissions(
db_session=db_session,
user_id_to_email_map=user_id_to_email_map,
)
channel_permissions = _fetch_channel_permissions(
db_session=db_session,
slack_client=slack_client,
workspace_permissions=workspace_permissions,
user_id_to_email_map=user_id_to_email_map,
)
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for this channel_id
continue
for doc_id in doc_ids:
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)


@@ -0,0 +1,92 @@
"""
THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING
WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL
SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
"""
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
logger = setup_logger()
def _get_slack_group_ids(
slack_client: WebClient,
) -> list[str]:
group_ids = []
for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
for group in result.get("usergroups", []):
group_ids.append(group.get("id"))
return group_ids
def _get_slack_group_members_email(
db_session: Session,
slack_client: WebClient,
group_name: str,
user_id_to_email_map: dict[str, str],
) -> list[str]:
group_member_emails = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.usergroups_users_list, usergroup=group_name
):
for member_id in result.get("users", []):
member_email = user_id_to_email_map.get(member_id)
if not member_email:
# If the user is an external user, they won't get returned from the
# conversations_members call, so we need to make a separate call to users_info
member_info = slack_client.users_info(user=member_id)
member_email = member_info["user"]["profile"].get("email")
if not member_email:
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
group_member_emails.append(member_email)
return group_member_emails
def slack_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
slack_client = WebClient(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
danswer_groups: list[ExternalUserGroup] = []
for group_name in _get_slack_group_ids(slack_client):
group_member_emails = _get_slack_group_members_email(
db_session=db_session,
slack_client=slack_client,
group_name=group_name,
user_id_to_email_map=user_id_to_email_map,
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)


@@ -0,0 +1,18 @@
from slack_sdk import WebClient
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
def fetch_user_id_to_email_map(
slack_client: WebClient,
) -> dict[str, str]:
user_id_to_email_map = {}
for user_info in make_paginated_slack_api_call_w_retries(
slack_client.users_list,
):
for user in user_info.get("members", []):
if user.get("profile", {}).get("email"):
user_id_to_email_map[user.get("id")] = user.get("profile", {}).get(
"email"
)
return user_id_to_email_map
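The pagination pattern above can be sketched without the Slack client: walk each page of `users.list` results and keep only members that expose a profile email (the helper name and sample payload below are illustrative):

```python
def build_user_email_map(pages: list[dict]) -> dict[str, str]:
    # Mirrors fetch_user_id_to_email_map: users without a profile
    # email (e.g. bots) are simply skipped.
    mapping: dict[str, str] = {}
    for page in pages:
        for user in page.get("members", []):
            email = user.get("profile", {}).get("email")
            if email:
                mapping[user.get("id")] = email
    return mapping


pages = [
    {"members": [{"id": "U1", "profile": {"email": "a@x.com"}},
                 {"id": "U2", "profile": {}}]},
    {"members": [{"id": "U3", "profile": {"email": "c@x.com"}}]},
]
print(build_user_email_map(pages))  # {'U1': 'a@x.com', 'U3': 'c@x.com'}
```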


@@ -1,5 +1,4 @@
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
@@ -9,15 +8,14 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_s
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
GroupSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
None,
]
DocSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
# Defining the input/output types for the sync functions
SyncFuncType = Callable[
[
Session,
ConnectorCredentialPair,
],
None,
]
@@ -26,27 +24,27 @@ DocSyncFuncType = Callable[
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
}
# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
}
# None means that the connector supports polling from last_time_perm_sync to now
FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
# Polling is supported
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 10 minutes
DocumentSource.CONFLUENCE: 10 * 60,
# If nothing is specified here, we run the doc_sync every time the celery beat runs
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.SLACK: 5 * 60,
}


@@ -1,3 +1,4 @@
from typing import Any
from typing import List
from pydantic import BaseModel
@@ -6,8 +7,20 @@ from pydantic import Field
class NavigationItem(BaseModel):
link: str
icon: str
title: str
# Right now must be one of the FA icons
icon: str | None = None
# NOTE: SVG must not have a width / height specified
# This is the actual SVG as a string. Done this way to reduce
# complexity and avoid storing additional "logos" in Postgres
svg_logo: str | None = None
@classmethod
def model_validate(cls, *args: Any, **kwargs: Any) -> "NavigationItem":
instance = super().model_validate(*args, **kwargs)
if bool(instance.icon) == bool(instance.svg_logo):
raise ValueError("Exactly one of icon or svg_logo must be specified")
return instance
class EnterpriseSettings(BaseModel):


@@ -12,7 +12,6 @@ from fastapi_users import exceptions
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from pydantic import EmailStr
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserCreate
@@ -61,7 +60,7 @@ async def upsert_saml_user(email: str) -> User:
user: User = await user_manager.create(
UserCreate(
email=EmailStr(email),
email=email,
password=hashed_pass,
is_verified=True,
role=role,


@@ -1,5 +1,6 @@
import json
import os
from copy import deepcopy
from typing import List
from typing import Optional
@@ -22,6 +23,7 @@ from ee.danswer.db.standard_answer import (
)
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.models import NavigationItem
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import (
store_settings as store_ee_settings,
@@ -44,6 +46,13 @@ logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
class NavigationItemSeed(BaseModel):
link: str
title: str
# NOTE: SVG at this path must not have a width / height specified
svg_path: str
class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
@@ -51,6 +60,10 @@ class SeedConfiguration(BaseModel):
personas: list[CreatePersonaRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
# Allows specifying custom navigation items with your own custom SVG logos
nav_item_overrides: list[NavigationItemSeed] | None = None
# Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference
analytics_script_path: str | None = None
custom_tools: List[CustomToolSeed] | None = None
@@ -60,7 +73,7 @@ def _parse_env() -> SeedConfiguration | None:
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
if not seed_config_str:
return None
seed_config = SeedConfiguration.parse_raw(seed_config_str)
seed_config = SeedConfiguration.model_validate_json(seed_config_str)
return seed_config
@@ -152,9 +165,35 @@ def _seed_settings(settings: Settings) -> None:
def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
if seed_config.enterprise_settings is not None:
if (
seed_config.enterprise_settings is not None
or seed_config.nav_item_overrides is not None
):
final_enterprise_settings = (
deepcopy(seed_config.enterprise_settings)
if seed_config.enterprise_settings
else EnterpriseSettings()
)
final_nav_items = final_enterprise_settings.custom_nav_items
if seed_config.nav_item_overrides is not None:
final_nav_items = []
for item in seed_config.nav_item_overrides:
with open(item.svg_path, "r") as file:
svg_content = file.read().strip()
final_nav_items.append(
NavigationItem(
link=item.link,
title=item.title,
svg_logo=svg_content,
)
)
final_enterprise_settings.custom_nav_items = final_nav_items
logger.notice("Seeding enterprise settings")
store_ee_settings(seed_config.enterprise_settings)
store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:


@@ -1,4 +1,8 @@
[pytest]
pythonpath = .
markers =
slow: marks tests as slow
slow: marks tests as slow
filterwarnings =
ignore::DeprecationWarning
ignore::cryptography.utils.CryptographyDeprecationWarning


@@ -4,7 +4,7 @@ asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.2
boto3==1.34.84
celery==5.3.4
celery==5.5.0b4
chardet==5.2.0
dask==2023.8.1
ddtrace==2.6.5


@@ -18,7 +18,8 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
def run_jobs(exclude_indexing: bool) -> None:
cmd_worker = [
# command setup
cmd_worker_primary = [
"celery",
"-A",
"ee.danswer.background.celery.celery_app",
@@ -26,8 +27,38 @@ def run_jobs(exclude_indexing: bool) -> None:
"--pool=threads",
"--concurrency=6",
"--loglevel=INFO",
"-n",
"primary@%n",
"-Q",
"celery,vespa_metadata_sync,connector_deletion",
"celery",
]
cmd_worker_light = [
"celery",
"-A",
"ee.danswer.background.celery.celery_app",
"worker",
"--pool=threads",
"--concurrency=16",
"--loglevel=INFO",
"-n",
"light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion",
]
cmd_worker_heavy = [
"celery",
"-A",
"ee.danswer.background.celery.celery_app",
"worker",
"--pool=threads",
"--concurrency=6",
"--loglevel=INFO",
"-n",
"heavy@%n",
"-Q",
"connector_pruning",
]
cmd_beat = [
@@ -38,19 +69,38 @@ def run_jobs(exclude_indexing: bool) -> None:
"--loglevel=INFO",
]
worker_process = subprocess.Popen(
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_light_process = subprocess.Popen(
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_heavy_process = subprocess.Popen(
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_thread = threading.Thread(
target=monitor_process, args=("WORKER", worker_process)
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
)
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_heavy_thread = threading.Thread(
target=monitor_process, args=("HEAVY", worker_heavy_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_thread.start()
worker_primary_thread.start()
worker_light_thread.start()
worker_heavy_thread.start()
beat_thread.start()
if not exclude_indexing:
@@ -93,7 +143,9 @@ def run_jobs(exclude_indexing: bool) -> None:
except Exception:
pass
worker_thread.join()
worker_primary_thread.join()
worker_light_thread.join()
worker_heavy_thread.join()
beat_thread.join()


@@ -24,23 +24,59 @@ autorestart=true
# on a system, but this should be okay for now since all our celery tasks are
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
[program:celery_worker_primary]
command=celery -A danswer.background.celery.celery_run:celery_app worker
--pool=threads
--concurrency=6
--concurrency=4
--prefetch-multiplier=1
--loglevel=INFO
--logfile=/var/log/celery_worker_supervisor.log
-Q celery,vespa_metadata_sync,connector_deletion
environment=LOG_FILE_NAME=celery_worker
--hostname=primary@%%n
-Q celery
stdout_logfile=/var/log/celery_worker_primary.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_light]
command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \
--pool=threads \
--concurrency=${CELERY_WORKER_LIGHT_CONCURRENCY:-24} \
--prefetch-multiplier=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-8} \
--loglevel=INFO \
--hostname=light@%%n \
-Q vespa_metadata_sync,connector_deletion"
stdout_logfile=/var/log/celery_worker_light.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_heavy]
command=celery -A danswer.background.celery.celery_run:celery_app worker
--pool=threads
--concurrency=4
--prefetch-multiplier=1
--loglevel=INFO
--hostname=heavy@%%n
-Q connector_pruning
stdout_logfile=/var/log/celery_worker_heavy.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat
--logfile=/var/log/celery_beat_supervisor.log
environment=LOG_FILE_NAME=celery_beat
command=celery -A danswer.background.celery.celery_run:celery_app beat
stdout_logfile=/var/log/celery_beat.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
startsecs=10
stopasgroup=true
# Listens for Slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
@@ -58,13 +94,12 @@ startsecs=60
# No log rotation here, since it's stdout it's handled by the Docker container logging
[program:log-redirect-handler]
command=tail -qF
/var/log/celery_beat.log
/var/log/celery_worker_primary.log
/var/log/celery_worker_light.log
/var/log/celery_worker_heavy.log
/var/log/document_indexing_info.log
/var/log/celery_beat_supervisor.log
/var/log/celery_worker_supervisor.log
/var/log/celery_beat_debug.log
/var/log/celery_worker_debug.log
/var/log/slack_bot_debug.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true
autorestart=true
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout
autorestart=true


@@ -0,0 +1,48 @@
import os
import time
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.danswer_jira.connector import JiraConnector
@pytest.fixture
def jira_connector() -> JiraConnector:
connector = JiraConnector(
"https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
comment_email_blacklist=[],
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
return connector
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
doc_batch_generator = jira_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 1
doc = doc_batch[0]
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
assert doc.semantic_identifier == "test123small"
assert doc.source == DocumentSource.JIRA
assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
assert doc.secondary_owners is None
assert doc.title is None
assert doc.from_ingestion_api is False
assert doc.additional_info is None
assert len(doc.sections) == 1
section = doc.sections[0]
assert section.text == "example_text\n"
assert section.link == "https://danswerai.atlassian.net/browse/AS-2"


@@ -72,6 +72,7 @@ COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Integration test stuff


@@ -6,8 +6,8 @@ from danswer.db.models import UserRole
from ee.danswer.server.api_key.models import APIKeyArgs
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
class APIKeyManager:
@@ -15,8 +15,8 @@ class APIKeyManager:
def create(
name: str | None = None,
api_key_role: UserRole = UserRole.ADMIN,
user_performing_action: TestUser | None = None,
) -> TestAPIKey:
user_performing_action: DATestUser | None = None,
) -> DATestAPIKey:
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
api_key_request = APIKeyArgs(
name=name,
@@ -31,7 +31,7 @@ class APIKeyManager:
)
api_key_response.raise_for_status()
api_key = api_key_response.json()
result_api_key = TestAPIKey(
result_api_key = DATestAPIKey(
api_key_id=api_key["api_key_id"],
api_key_display=api_key["api_key_display"],
api_key=api_key["api_key"],
@@ -45,8 +45,8 @@ class APIKeyManager:
@staticmethod
def delete(
api_key: TestAPIKey,
user_performing_action: TestUser | None = None,
api_key: DATestAPIKey,
user_performing_action: DATestUser | None = None,
) -> None:
api_key_response = requests.delete(
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
@@ -58,8 +58,8 @@ class APIKeyManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
) -> list[TestAPIKey]:
user_performing_action: DATestUser | None = None,
) -> list[DATestAPIKey]:
api_key_response = requests.get(
f"{API_SERVER_URL}/admin/api-key",
headers=user_performing_action.headers
@@ -67,13 +67,13 @@ class APIKeyManager:
else GENERAL_HEADERS,
)
api_key_response.raise_for_status()
return [TestAPIKey(**api_key) for api_key in api_key_response.json()]
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
@staticmethod
def verify(
api_key: TestAPIKey,
api_key: DATestAPIKey,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_keys = APIKeyManager.get_all(
user_performing_action=user_performing_action


@@ -1,4 +1,5 @@
import time
from datetime import datetime
from typing import Any
from uuid import uuid4
@@ -7,6 +8,8 @@ import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import TaskStatus
from danswer.server.documents.models import CCPairPruningTask
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import DocumentSource
@@ -15,8 +18,8 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def _cc_pair_creator(
@@ -25,8 +28,8 @@ def _cc_pair_creator(
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
request = {
@@ -43,7 +46,7 @@ def _cc_pair_creator(
else GENERAL_HEADERS,
)
response.raise_for_status()
return TestCCPair(
return DATestCCPair(
id=response.json()["data"],
name=name,
connector_id=connector_id,
@@ -63,8 +66,8 @@ class CCPairManager:
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
name=name,
source=source,
@@ -98,8 +101,8 @@ class CCPairManager:
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
return _cc_pair_creator(
connector_id=connector_id,
credential_id=credential_id,
@@ -111,8 +114,8 @@ class CCPairManager:
@staticmethod
def pause_cc_pair(
cc_pair: TestCCPair,
user_performing_action: TestUser | None = None,
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
@@ -125,8 +128,8 @@ class CCPairManager:
@staticmethod
def delete(
cc_pair: TestCCPair,
user_performing_action: TestUser | None = None,
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
@@ -141,9 +144,28 @@ class CCPairManager:
)
result.raise_for_status()
@staticmethod
def get_one(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> ConnectorIndexingStatus | None:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
for cc_pair_json in response.json():
cc_pair = ConnectorIndexingStatus(**cc_pair_json)
if cc_pair.cc_pair_id == cc_pair_id:
return cc_pair
return None
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorIndexingStatus]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
@@ -156,9 +178,9 @@ class CCPairManager:
@staticmethod
def verify(
cc_pair: TestCCPair,
cc_pair: DATestCCPair,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_all(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
@@ -182,10 +204,99 @@ class CCPairManager:
raise ValueError(f"CC pair {cc_pair.id} not found")
@staticmethod
def wait_for_deletion_completion(
user_performing_action: TestUser | None = None,
def wait_for_indexing(
cc_pair_test: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
while True:
cc_pairs = CCPairManager.get_all(user_performing_action)
for cc_pair in cc_pairs:
if cc_pair.cc_pair_id != cc_pair_test.id:
continue
if cc_pair.last_success and cc_pair.last_success > after:
print(f"cc_pair {cc_pair_test.id} indexing complete.")
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"CC pair indexing was not completed within {timeout} seconds"
)
print(
f"Waiting for CC indexing to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
@staticmethod
def prune(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
result.raise_for_status()
@staticmethod
def get_prune_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> CCPairPruningTask:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return CCPairPruningTask(**response.json())
@staticmethod
def wait_for_prune(
cc_pair_test: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action)
if not task:
raise ValueError("Prune task not found.")
if not task.register_time or task.register_time < after:
raise ValueError("Prune task register time is too early.")
if task.status == TaskStatus.SUCCESS:
# Pruning succeeded
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"CC pair pruning was not completed within {timeout} seconds"
)
print(
f"Waiting for CC pruning to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
@staticmethod
def wait_for_deletion_completion(
user_performing_action: DATestUser | None = None,
) -> None:
start = time.monotonic()
while True:
cc_pairs = CCPairManager.get_all(user_performing_action)
if all(
@@ -194,7 +305,7 @@ class CCPairManager:
):
return
if time.time() - start > MAX_DELAY:
if time.monotonic() - start > MAX_DELAY:
raise TimeoutError(
f"CC pairs deletion was not completed within the {MAX_DELAY} seconds"
)
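Several hunks above replace `time.time()` with `time.monotonic()` for timeout bookkeeping: the monotonic clock cannot jump backward or forward with wall-clock adjustments (NTP sync, DST), so elapsed-time checks stay correct. A minimal sketch of the polling pattern these managers share (`wait_until` and its parameters are illustrative names, not part of the codebase):

```python
import time


def wait_until(condition, timeout: float = 60.0, poll: float = 5.0) -> None:
    """Poll `condition` until it returns True or `timeout` seconds elapse.

    Uses time.monotonic() so the deadline is unaffected by wall-clock changes,
    which is exactly why the hotfix above moved away from time.time().
    """
    start = time.monotonic()
    while True:
        if condition():
            return
        elapsed = time.monotonic() - start
        if elapsed > timeout:
            raise TimeoutError(f"condition not met within {timeout} seconds")
        time.sleep(poll)
```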

View File

@@ -13,10 +13,10 @@ from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestChatMessage
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import StreamedResponse
from tests.integration.common_utils.test_models import TestChatMessage
from tests.integration.common_utils.test_models import TestChatSession
from tests.integration.common_utils.test_models import TestUser
class ChatSessionManager:
@@ -24,8 +24,8 @@ class ChatSessionManager:
def create(
persona_id: int = -1,
description: str = "Test chat session",
user_performing_action: TestUser | None = None,
) -> TestChatSession:
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
)
@@ -38,7 +38,7 @@ class ChatSessionManager:
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
return TestChatSession(
return DATestChatSession(
id=chat_session_id, persona_id=persona_id, description=description
)
@@ -47,7 +47,7 @@ class ChatSessionManager:
chat_session_id: int,
message: str,
parent_message_id: int | None = None,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] = [],
prompt_id: int | None = None,
search_doc_ids: list[int] | None = None,
@@ -90,7 +90,7 @@ class ChatSessionManager:
def get_answer_with_quote(
persona_id: int,
message: str,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> StreamedResponse:
direct_qa_request = DirectQARequest(
messages=[ThreadMessage(message=message)],
@@ -137,9 +137,9 @@ class ChatSessionManager:
@staticmethod
def get_chat_history(
chat_session: TestChatSession,
user_performing_action: TestUser | None = None,
) -> list[TestChatMessage]:
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> list[DATestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
headers=user_performing_action.headers
@@ -149,7 +149,7 @@ class ChatSessionManager:
response.raise_for_status()
return [
TestChatMessage(
DATestChatMessage(
id=msg["id"],
chat_session_id=chat_session.id,
parent_message_id=msg.get("parent_message_id"),

View File

@@ -8,8 +8,8 @@ from danswer.server.documents.models import ConnectorUpdateRequest
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestConnector
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestUser
class ConnectorManager:
@@ -21,8 +21,8 @@ class ConnectorManager:
connector_specific_config: dict[str, Any] | None = None,
is_public: bool = True,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestConnector:
user_performing_action: DATestUser | None = None,
) -> DATestConnector:
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
connector_update_request = ConnectorUpdateRequest(
@@ -44,7 +44,7 @@ class ConnectorManager:
response.raise_for_status()
response_data = response.json()
return TestConnector(
return DATestConnector(
id=response_data.get("id"),
name=name,
source=source,
@@ -56,8 +56,8 @@ class ConnectorManager:
@staticmethod
def edit(
connector: TestConnector,
user_performing_action: TestUser | None = None,
connector: DATestConnector,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
@@ -70,8 +70,8 @@ class ConnectorManager:
@staticmethod
def delete(
connector: TestConnector,
user_performing_action: TestUser | None = None,
connector: DATestConnector,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
@@ -83,8 +83,8 @@ class ConnectorManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
) -> list[TestConnector]:
user_performing_action: DATestUser | None = None,
) -> list[DATestConnector]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector",
headers=user_performing_action.headers
@@ -93,7 +93,7 @@ class ConnectorManager:
)
response.raise_for_status()
return [
TestConnector(
DATestConnector(
id=conn.get("id"),
name=conn.get("name", ""),
source=conn.get("source", DocumentSource.FILE),
@@ -105,8 +105,8 @@ class ConnectorManager:
@staticmethod
def get(
connector_id: int, user_performing_action: TestUser | None = None
) -> TestConnector:
connector_id: int, user_performing_action: DATestUser | None = None
) -> DATestConnector:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
headers=user_performing_action.headers
@@ -115,7 +115,7 @@ class ConnectorManager:
)
response.raise_for_status()
conn = response.json()
return TestConnector(
return DATestConnector(
id=conn.get("id"),
name=conn.get("name", ""),
source=conn.get("source", DocumentSource.FILE),

View File

@@ -7,8 +7,8 @@ from danswer.server.documents.models import CredentialSnapshot
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestCredential
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
class CredentialManager:
@@ -20,8 +20,8 @@ class CredentialManager:
source: DocumentSource = DocumentSource.FILE,
curator_public: bool = True,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCredential:
user_performing_action: DATestUser | None = None,
) -> DATestCredential:
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
credential_request = {
@@ -41,7 +41,7 @@ class CredentialManager:
)
response.raise_for_status()
return TestCredential(
return DATestCredential(
id=response.json()["id"],
name=name,
credential_json=credential_json or {},
@@ -53,8 +53,8 @@ class CredentialManager:
@staticmethod
def edit(
credential: TestCredential,
user_performing_action: TestUser | None = None,
credential: DATestCredential,
user_performing_action: DATestUser | None = None,
) -> None:
request = credential.model_dump(include={"name", "credential_json"})
response = requests.put(
@@ -68,8 +68,8 @@ class CredentialManager:
@staticmethod
def delete(
credential: TestCredential,
user_performing_action: TestUser | None = None,
credential: DATestCredential,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
@@ -81,7 +81,7 @@ class CredentialManager:
@staticmethod
def get(
credential_id: int, user_performing_action: TestUser | None = None
credential_id: int, user_performing_action: DATestUser | None = None
) -> CredentialSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
@@ -94,7 +94,7 @@ class CredentialManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> list[CredentialSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/manage/credential",
@@ -107,9 +107,9 @@ class CredentialManager:
@staticmethod
def verify(
credential: TestCredential,
credential: DATestCredential,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
all_credentials = CredentialManager.get_all(user_performing_action)
for fetched_credential in all_credentials:

View File

@@ -7,19 +7,19 @@ from danswer.db.enums import AccessType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import TestAPIKey
from tests.integration.common_utils.managers.cc_pair import TestCCPair
from tests.integration.common_utils.managers.api_key import DATestAPIKey
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.vespa import TestVespaClient
from tests.integration.common_utils.vespa import vespa_fixture
def _verify_document_permissions(
retrieved_doc: dict,
cc_pair: TestCCPair,
cc_pair: DATestCCPair,
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: TestUser | None = None,
doc_creating_user: DATestUser | None = None,
) -> None:
acl_keys = set(retrieved_doc["access_control_list"].keys())
print(f"ACL keys: {acl_keys}")
@@ -83,10 +83,10 @@ def _generate_dummy_document(
class DocumentManager:
@staticmethod
def seed_dummy_docs(
cc_pair: TestCCPair,
cc_pair: DATestCCPair,
num_docs: int = NUM_DOCS,
document_ids: list[str] | None = None,
api_key: TestAPIKey | None = None,
api_key: DATestAPIKey | None = None,
) -> list[SimpleTestDocument]:
# Use provided document_ids if available, otherwise generate random UUIDs
if document_ids is None:
@@ -116,10 +116,10 @@ class DocumentManager:
@staticmethod
def seed_doc_with_content(
cc_pair: TestCCPair,
cc_pair: DATestCCPair,
content: str,
document_id: str | None = None,
api_key: TestAPIKey | None = None,
api_key: DATestAPIKey | None = None,
) -> SimpleTestDocument:
# Use the provided document_id if available, otherwise generate a random UUID
if document_id is None:
@@ -142,13 +142,13 @@ class DocumentManager:
@staticmethod
def verify(
vespa_client: TestVespaClient,
cc_pair: TestCCPair,
vespa_client: vespa_fixture,
cc_pair: DATestCCPair,
# If None, will not check doc sets or groups
# If empty list, will check for empty doc sets or groups
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: TestUser | None = None,
doc_creating_user: DATestUser | None = None,
verify_deleted: bool = False,
) -> None:
doc_ids = [document.id for document in cc_pair.documents]

View File

@@ -6,8 +6,8 @@ import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import TestDocumentSet
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestDocumentSet
from tests.integration.common_utils.test_models import DATestUser
class DocumentSetManager:
@@ -19,8 +19,8 @@ class DocumentSetManager:
is_public: bool = True,
users: list[str] | None = None,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestDocumentSet:
user_performing_action: DATestUser | None = None,
) -> DATestDocumentSet:
if name is None:
name = f"test_doc_set_{str(uuid4())}"
@@ -42,7 +42,7 @@ class DocumentSetManager:
)
response.raise_for_status()
return TestDocumentSet(
return DATestDocumentSet(
id=int(response.json()),
name=name,
description=description or name,
@@ -55,8 +55,8 @@ class DocumentSetManager:
@staticmethod
def edit(
document_set: TestDocumentSet,
user_performing_action: TestUser | None = None,
document_set: DATestDocumentSet,
user_performing_action: DATestUser | None = None,
) -> bool:
doc_set_update_request = {
"id": document_set.id,
@@ -78,8 +78,8 @@ class DocumentSetManager:
@staticmethod
def delete(
document_set: TestDocumentSet,
user_performing_action: TestUser | None = None,
document_set: DATestDocumentSet,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
@@ -92,8 +92,8 @@ class DocumentSetManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
) -> list[TestDocumentSet]:
user_performing_action: DATestUser | None = None,
) -> list[DATestDocumentSet]:
response = requests.get(
f"{API_SERVER_URL}/manage/document-set",
headers=user_performing_action.headers
@@ -102,7 +102,7 @@ class DocumentSetManager:
)
response.raise_for_status()
return [
TestDocumentSet(
DATestDocumentSet(
id=doc_set["id"],
name=doc_set["name"],
description=doc_set["description"],
@@ -119,8 +119,8 @@ class DocumentSetManager:
@staticmethod
def wait_for_sync(
document_sets_to_check: list[TestDocumentSet] | None = None,
user_performing_action: TestUser | None = None,
document_sets_to_check: list[DATestDocumentSet] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
# wait for document sets to be synced
start = time.time()
@@ -148,9 +148,9 @@ class DocumentSetManager:
@staticmethod
def verify(
document_set: TestDocumentSet,
document_set: DATestDocumentSet,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
doc_sets = DocumentSetManager.get_all(user_performing_action)
for doc_set in doc_sets:

View File

@@ -3,11 +3,12 @@ from uuid import uuid4
import requests
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestLLMProvider
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
class LLMProviderManager:
@@ -21,8 +22,8 @@ class LLMProviderManager:
api_version: str | None = None,
groups: list[int] | None = None,
is_public: bool | None = None,
user_performing_action: TestUser | None = None,
) -> TestLLMProvider:
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
print("Seeding LLM Providers...")
llm_provider = LLMProviderUpsertRequest(
@@ -49,7 +50,10 @@ class LLMProviderManager:
)
llm_response.raise_for_status()
response_data = llm_response.json()
result_llm = TestLLMProvider(
import json
print(json.dumps(response_data, indent=4))
result_llm = DATestLLMProvider(
id=response_data["id"],
name=response_data["name"],
provider=response_data["provider"],
@@ -73,11 +77,9 @@ class LLMProviderManager:
@staticmethod
def delete(
llm_provider: TestLLMProvider,
user_performing_action: TestUser | None = None,
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser | None = None,
) -> bool:
if not llm_provider.id:
raise ValueError("LLM Provider ID is required to delete a provider")
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=user_performing_action.headers
@@ -86,3 +88,43 @@ class LLMProviderManager:
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullLLMProvider(**provider_json) for provider_json in response.json()]
@staticmethod
def verify(
llm_provider: DATestLLMProvider,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
for fetched_llm_provider in all_llm_providers:
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
f"User group {llm_provider.id} found but should be deleted"
)
fetched_llm_groups = set(fetched_llm_provider.groups)
llm_provider_groups = set(llm_provider.groups)
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.api_key == fetched_llm_provider.api_key
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
):
return
if not verify_deleted:
raise ValueError(f"User group {llm_provider.id} not found")

View File

@@ -6,8 +6,8 @@ from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import PersonaSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestPersona
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestUser
class PersonaManager:
@@ -27,8 +27,8 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestPersona:
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
description = description or f"Description for {name}"
@@ -59,7 +59,7 @@ class PersonaManager:
response.raise_for_status()
persona_data = response.json()
return TestPersona(
return DATestPersona(
id=persona_data["id"],
name=name,
description=description,
@@ -79,7 +79,7 @@ class PersonaManager:
@staticmethod
def edit(
persona: TestPersona,
persona: DATestPersona,
name: str | None = None,
description: str | None = None,
num_chunks: float | None = None,
@@ -94,8 +94,8 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestPersona:
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
persona_update_request = {
"name": name or persona.name,
"description": description or persona.description,
@@ -127,7 +127,7 @@ class PersonaManager:
response.raise_for_status()
updated_persona_data = response.json()
return TestPersona(
return DATestPersona(
id=updated_persona_data["id"],
name=updated_persona_data["name"],
description=updated_persona_data["description"],
@@ -151,7 +151,7 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
@@ -164,38 +164,46 @@ class PersonaManager:
@staticmethod
def verify(
test_persona: TestPersona,
user_performing_action: TestUser | None = None,
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_all(user_performing_action)
for persona in all_personas:
if persona.id == test_persona.id:
for fetched_persona in all_personas:
if fetched_persona.id == persona.id:
return (
persona.name == test_persona.name
and persona.description == test_persona.description
and persona.num_chunks == test_persona.num_chunks
and persona.llm_relevance_filter
== test_persona.llm_relevance_filter
and persona.is_public == test_persona.is_public
and persona.llm_filter_extraction
== test_persona.llm_filter_extraction
and persona.llm_model_provider_override
== test_persona.llm_model_provider_override
and persona.llm_model_version_override
== test_persona.llm_model_version_override
and set(persona.prompts) == set(test_persona.prompt_ids)
and set(persona.document_sets) == set(test_persona.document_set_ids)
and set(persona.tools) == set(test_persona.tool_ids)
and set(user.email for user in persona.users)
== set(test_persona.users)
and set(persona.groups) == set(test_persona.groups)
fetched_persona.name == persona.name
and fetched_persona.description == persona.description
and fetched_persona.num_chunks == persona.num_chunks
and fetched_persona.llm_relevance_filter
== persona.llm_relevance_filter
and fetched_persona.is_public == persona.is_public
and fetched_persona.llm_filter_extraction
== persona.llm_filter_extraction
and fetched_persona.llm_model_provider_override
== persona.llm_model_provider_override
and fetched_persona.llm_model_version_override
== persona.llm_model_version_override
and set([prompt.id for prompt in fetched_persona.prompts])
== set(persona.prompt_ids)
and set(
[
document_set.id
for document_set in fetched_persona.document_sets
]
)
== set(persona.document_set_ids)
and set([tool.id for tool in fetched_persona.tools])
== set(persona.tool_ids)
and set(user.email for user in fetched_persona.users)
== set(persona.users)
and set(fetched_persona.groups) == set(persona.groups)
)
return False
@staticmethod
def delete(
persona: TestPersona,
user_performing_action: TestUser | None = None,
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/persona/{persona.id}",

View File

@@ -10,14 +10,14 @@ from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestUser
class UserManager:
@staticmethod
def create(
name: str | None = None,
) -> TestUser:
) -> DATestUser:
if name is None:
name = f"test{str(uuid4())}"
@@ -36,7 +36,7 @@ class UserManager:
)
response.raise_for_status()
test_user = TestUser(
test_user = DATestUser(
id=response.json()["id"],
email=email,
password=password,
@@ -49,7 +49,7 @@ class UserManager:
return test_user
@staticmethod
def login_as_user(test_user: TestUser) -> str:
def login_as_user(test_user: DATestUser) -> str:
data = urlencode(
{
"username": test_user.email,
@@ -74,7 +74,7 @@ class UserManager:
return f"{result_cookie.name}={result_cookie.value}"
@staticmethod
def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool:
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_to_verify.headers,
@@ -84,9 +84,9 @@ class UserManager:
@staticmethod
def set_role(
user_to_set: TestUser,
user_to_set: DATestUser,
target_role: UserRole,
user_to_perform_action: TestUser | None = None,
user_to_perform_action: DATestUser | None = None,
) -> None:
if user_to_perform_action is None:
user_to_perform_action = user_to_set
@@ -98,7 +98,9 @@ class UserManager:
response.raise_for_status()
@staticmethod
def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None:
def verify(
user: DATestUser, user_to_perform_action: DATestUser | None = None
) -> None:
if user_to_perform_action is None:
user_to_perform_action = user
response = requests.get(

View File

@@ -7,8 +7,8 @@ from ee.danswer.server.user_group.models import UserGroup
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import TestUserGroup
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import DATestUserGroup
class UserGroupManager:
@@ -17,8 +17,8 @@ class UserGroupManager:
name: str | None = None,
user_ids: list[str] | None = None,
cc_pair_ids: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestUserGroup:
user_performing_action: DATestUser | None = None,
) -> DATestUserGroup:
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
request = {
@@ -34,7 +34,7 @@ class UserGroupManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
test_user_group = TestUserGroup(
test_user_group = DATestUserGroup(
id=response.json()["id"],
name=response.json()["name"],
user_ids=[user["id"] for user in response.json()["users"]],
@@ -44,11 +44,9 @@ class UserGroupManager:
@staticmethod
def edit(
user_group: TestUserGroup,
user_performing_action: TestUser | None = None,
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
) -> None:
if not user_group.id:
raise ValueError("User group has no ID")
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
json=user_group.model_dump(),
@@ -59,14 +57,25 @@ class UserGroupManager:
response.raise_for_status()
@staticmethod
def set_curator_status(
test_user_group: TestUserGroup,
user_to_set_as_curator: TestUser,
is_curator: bool = True,
user_performing_action: TestUser | None = None,
def delete(
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
@staticmethod
def set_curator_status(
test_user_group: DATestUserGroup,
user_to_set_as_curator: DATestUser,
is_curator: bool = True,
user_performing_action: DATestUser | None = None,
) -> None:
if not user_to_set_as_curator.id:
raise ValueError("User has no ID")
set_curator_request = {
"user_id": user_to_set_as_curator.id,
"is_curator": is_curator,
@@ -82,7 +91,7 @@ class UserGroupManager:
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> list[UserGroup]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
@@ -95,9 +104,9 @@ class UserGroupManager:
@staticmethod
def verify(
user_group: TestUserGroup,
user_group: DATestUserGroup,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
all_user_groups = UserGroupManager.get_all(user_performing_action)
for fetched_user_group in all_user_groups:
@@ -120,8 +129,8 @@ class UserGroupManager:
@staticmethod
def wait_for_sync(
user_groups_to_check: list[TestUserGroup] | None = None,
user_performing_action: TestUser | None = None,
user_groups_to_check: list[DATestUserGroup] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
while True:
@@ -130,7 +139,7 @@ class UserGroupManager:
check_ids = {user_group.id for user_group in user_groups_to_check}
user_group_ids = {user_group.id for user_group in user_groups}
if not check_ids.issubset(user_group_ids):
raise RuntimeError("Document set not found")
raise RuntimeError("User group not found")
user_groups = [
user_group
for user_group in user_groups
@@ -146,3 +155,26 @@ class UserGroupManager:
else:
print("User groups were not synced yet, waiting...")
time.sleep(2)
@staticmethod
def wait_for_deletion_completion(
user_groups_to_check: list[DATestUserGroup],
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}
while True:
fetched_user_groups = UserGroupManager.get_all(user_performing_action)
fetched_user_group_ids = {
user_group.id for user_group in fetched_user_groups
}
if not user_group_ids_to_check.intersection(fetched_user_group_ids):
return
if time.time() - start > MAX_DELAY:
raise TimeoutError(
f"User groups deletion was not completed within the {MAX_DELAY} seconds"
)
else:
print("Some user groups are still being deleted, waiting...")
time.sleep(2)
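The new `wait_for_deletion_completion` helper above is a poll-until-absent loop with a timeout. A minimal standalone sketch of the same pattern, where `fetch_ids` is a hypothetical zero-argument callable standing in for `UserGroupManager.get_all` (names and the `MAX_DELAY` value here are illustrative assumptions, not part of the diff):

```python
import time

MAX_DELAY = 45  # seconds; assumed value -- the real constant lives in the test constants module


def wait_until_absent(fetch_ids, ids_to_check, max_delay=MAX_DELAY, poll_interval=2.0):
    """Poll fetch_ids() until none of ids_to_check remain, or time out.

    fetch_ids is a hypothetical zero-argument callable returning the
    currently existing IDs; it stands in for fetching all user groups.
    """
    start = time.time()
    while True:
        # Only IDs that are both "being deleted" and still present block us.
        remaining = set(ids_to_check) & set(fetch_ids())
        if not remaining:
            return
        if time.time() - start > max_delay:
            raise TimeoutError(
                f"IDs {remaining} were not deleted within {max_delay} seconds"
            )
        time.sleep(poll_interval)
```

Using set intersection (as the diff does) rather than per-ID loops keeps each poll O(n) and makes the "all gone" check a single truthiness test.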


@@ -20,6 +20,9 @@ from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _run_migrations(
@@ -165,8 +168,8 @@ def reset_vespa() -> None:
def reset_all() -> None:
"""Reset both Postgres and Vespa."""
print("Resetting Postgres...")
logger.info("Resetting Postgres...")
reset_postgres()
print("Resetting Vespa...")
logger.info("Resetting Vespa...")
reset_vespa()
print("Finished resetting all.")
logger.info("Finished resetting all.")


@@ -20,7 +20,7 @@ This means the flow is:
"""
class TestAPIKey(BaseModel):
class DATestAPIKey(BaseModel):
api_key_id: int
api_key_display: str
api_key: str | None = None # only present on initial creation
@@ -31,14 +31,14 @@ class TestAPIKey(BaseModel):
headers: dict
class TestUser(BaseModel):
class DATestUser(BaseModel):
id: str
email: str
password: str
headers: dict
class TestCredential(BaseModel):
class DATestCredential(BaseModel):
id: int
name: str
credential_json: dict[str, Any]
@@ -48,7 +48,7 @@ class TestCredential(BaseModel):
groups: list[int]
class TestConnector(BaseModel):
class DATestConnector(BaseModel):
id: int
name: str
source: DocumentSource
@@ -63,7 +63,7 @@ class SimpleTestDocument(BaseModel):
content: str
class TestCCPair(BaseModel):
class DATestCCPair(BaseModel):
id: int
name: str
connector_id: int
@@ -73,26 +73,26 @@ class TestCCPair(BaseModel):
documents: list[SimpleTestDocument] = Field(default_factory=list)
class TestUserGroup(BaseModel):
class DATestUserGroup(BaseModel):
id: int
name: str
user_ids: list[str]
cc_pair_ids: list[int]
class TestLLMProvider(BaseModel):
class DATestLLMProvider(BaseModel):
id: int
name: str
provider: str
api_key: str
default_model_name: str
is_public: bool
groups: list[TestUserGroup]
groups: list[int]
api_base: str | None = None
api_version: str | None = None
class TestDocumentSet(BaseModel):
class DATestDocumentSet(BaseModel):
id: int
name: str
description: str
@@ -103,7 +103,7 @@ class TestDocumentSet(BaseModel):
groups: list[int] = Field(default_factory=list)
class TestPersona(BaseModel):
class DATestPersona(BaseModel):
id: int
name: str
description: str
@@ -122,13 +122,13 @@ class TestPersona(BaseModel):
#
class TestChatSession(BaseModel):
class DATestChatSession(BaseModel):
id: int
persona_id: int
description: str
class TestChatMessage(BaseModel):
class DATestChatMessage(BaseModel):
id: str | None = None
chat_session_id: int
parent_message_id: str | None
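The renamed test models above are Pydantic `BaseModel` subclasses. Their field layout can be sketched with a stdlib dataclass stand-in (field names are taken from the diff; the dataclass and the sample values are only an illustration — the real classes extend `pydantic.BaseModel`):

```python
from dataclasses import dataclass, field


@dataclass
class DATestUser:
    # Field names mirror the DATestUser model in the diff.
    id: str
    email: str
    password: str
    headers: dict = field(default_factory=dict)


# Hypothetical instance, shaped like what UserManager.create would return.
admin = DATestUser(
    id="u-1",
    email="admin_user@test.com",
    password="secret",
    headers={"Authorization": "Bearer test-api-key"},
)
```

The `headers` dict is what the managers pass to `requests` when `user_performing_action` is supplied, falling back to `GENERAL_HEADERS` otherwise.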


@@ -3,7 +3,7 @@ import requests
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
class TestVespaClient:
class vespa_fixture:
def __init__(self, index_name: str):
self.index_name = index_name
self.vespa_document_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)


@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.vespa import TestVespaClient
from tests.integration.common_utils.vespa import vespa_fixture
def load_env_vars(env_file: str = ".env") -> None:
@@ -36,9 +36,9 @@ def db_session() -> Generator[Session, None, None]:
@pytest.fixture
def vespa_client(db_session: Session) -> TestVespaClient:
def vespa_client(db_session: Session) -> vespa_fixture:
search_settings = get_current_search_settings(db_session)
return TestVespaClient(index_name=search_settings.index_name)
return vespa_fixture(index_name=search_settings.index_name)
@pytest.fixture


@@ -22,17 +22,17 @@ from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import TestUserGroup
from tests.integration.common_utils.vespa import TestVespaClient
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
@@ -76,11 +76,11 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
print("Document sets created and synced")
# create user groups
user_group_1: TestUserGroup = UserGroupManager.create(
user_group_1: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2: TestUserGroup = UserGroupManager.create(
user_group_2: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
@@ -174,15 +174,15 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
def test_connector_deletion_for_overlapping_connectors(
reset: None, vespa_client: TestVespaClient
reset: None, vespa_client: vespa_fixture
) -> None:
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
@@ -251,7 +251,7 @@ def test_connector_deletion_for_overlapping_connectors(
)
# create a user group and attach it to connector 1
user_group_1: TestUserGroup = UserGroupManager.create(
user_group_1: DATestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
@@ -265,7 +265,7 @@ def test_connector_deletion_for_overlapping_connectors(
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2: TestUserGroup = UserGroupManager.create(
user_group_2: DATestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,


@@ -2,25 +2,25 @@ import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connector
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
)
api_key: TestAPIKey = APIKeyManager.create(
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
LLMProviderManager.create(user_performing_action=admin_user)


@@ -3,25 +3,25 @@ import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
)
api_key: TestAPIKey = APIKeyManager.create(
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
LLMProviderManager.create(user_performing_action=admin_user)
@@ -64,13 +64,13 @@ def test_send_message_simple_with_history(reset: None) -> None:
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connector
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
)
api_key: TestAPIKey = APIKeyManager.create(
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
LLMProviderManager.create(user_performing_action=admin_user)


@@ -5,19 +5,19 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.vespa import TestVespaClient
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
def test_multiple_document_sets_syncing_same_connnector(
reset: None, vespa_client: TestVespaClient
reset: None, vespa_client: vespa_fixture
) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
@@ -66,12 +66,12 @@ def test_multiple_document_sets_syncing_same_connnector(
)
def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None:
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
# create api key
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)


@@ -10,17 +10,17 @@ from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import DATestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_cc_pair_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
admin_user: DATestUser = UserManager.create(name="admin_user")
# Creating a curator
curator: TestUser = UserManager.create(name="curator")
curator: DATestUser = UserManager.create(name="curator")
# Creating a user group
user_group_1 = UserGroupManager.create(

Some files were not shown because too many files have changed in this diff.