mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-23 18:55:45 +00:00
Compare commits
52 Commits
debug-test
...
v0.7.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f0af86454 | ||
|
|
0e6524dd32 | ||
|
|
2be133d784 | ||
|
|
cb668bcff5 | ||
|
|
756385e3ac | ||
|
|
1966127bd4 | ||
|
|
3ac84da698 | ||
|
|
7c7f5b37f5 | ||
|
|
0bf9243891 | ||
|
|
cfe4bbe3c7 | ||
|
|
9d18b92b90 | ||
|
|
74315e21b3 | ||
|
|
f9a5b227a1 | ||
|
|
3e511497d2 | ||
|
|
b0056907fb | ||
|
|
728a41a35a | ||
|
|
ef8dda2d47 | ||
|
|
15283b3140 | ||
|
|
e159b2e947 | ||
|
|
9155800fab | ||
|
|
a392ef0541 | ||
|
|
5679f0af61 | ||
|
|
ff8db71cb5 | ||
|
|
1cff2b82fd | ||
|
|
50dd3c8beb | ||
|
|
66a459234d | ||
|
|
19e57474dc | ||
|
|
f9638f2ea5 | ||
|
|
fbf51b70d0 | ||
|
|
b97cc01bb2 | ||
|
|
6d48fd5d99 | ||
|
|
1f61447b4b | ||
|
|
deee2b3513 | ||
|
|
b73d66c84a | ||
|
|
c5a61f4820 | ||
|
|
ea4a3cbf86 | ||
|
|
166514cedf | ||
|
|
be50ae1e71 | ||
|
|
f89504ec53 | ||
|
|
6b3213b1e4 | ||
|
|
48577bf0e4 | ||
|
|
c59d1ff0a5 | ||
|
|
ba38dec592 | ||
|
|
f5adc3063e | ||
|
|
8cfe80c53a | ||
|
|
487250320b | ||
|
|
c8d13922a9 | ||
|
|
cb75449cec | ||
|
|
b66514cd21 | ||
|
|
77650c9ee3 | ||
|
|
316b6b99ea | ||
|
|
34c2aa0860 |
@@ -7,16 +7,17 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -31,7 +32,7 @@ jobs:
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
@@ -41,12 +42,20 @@ jobs:
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -5,14 +5,18 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -31,13 +35,21 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-model-server:${{ github.ref_name }}
|
||||
danswer/danswer-model-server:latest
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -7,7 +7,8 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -35,7 +36,7 @@ jobs:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -112,8 +113,16 @@ jobs:
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
7
.github/workflows/docker-tag-latest.yml
vendored
7
.github/workflows/docker-tag-latest.yml
vendored
@@ -1,3 +1,6 @@
|
||||
# This workflow is set up to be manually triggered via the GitHub Action tab.
|
||||
# Given a version, it will tag those backend and webserver images as "latest".
|
||||
|
||||
name: Tag Latest Version
|
||||
|
||||
on:
|
||||
@@ -9,7 +12,9 @@ on:
|
||||
|
||||
jobs:
|
||||
tag:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
@@ -12,7 +12,8 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
|
||||
7
.github/workflows/pr-python-checks.yml
vendored
7
.github/workflows/pr-python-checks.yml
vendored
@@ -3,11 +3,14 @@ name: Python Checks
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
jobs:
|
||||
mypy-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
@@ -15,10 +15,14 @@ env:
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
7
.github/workflows/pr-python-tests.yml
vendored
7
.github/workflows/pr-python-tests.yml
vendored
@@ -3,11 +3,14 @@ name: Python Unit Tests
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
jobs:
|
||||
backend-check:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
5
.github/workflows/pr-quality-checks.yml
vendored
5
.github/workflows/pr-quality-checks.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Quality Checks PR
|
||||
concurrency:
|
||||
group: Quality-Checks-PR-${{ github.head_ref }}
|
||||
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
@@ -9,7 +9,8 @@ on:
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
11
.github/workflows/run-it.yml
vendored
11
.github/workflows/run-it.yml
vendored
@@ -1,19 +1,22 @@
|
||||
name: Run Integration Tests
|
||||
concurrency:
|
||||
group: Run-Integration-Tests-${{ github.head_ref }}
|
||||
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on: Amd64
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=32,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -120,6 +123,7 @@ jobs:
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
@@ -128,6 +132,7 @@ jobs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/integration-test-runner:it
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
54
.github/workflows/tag-nightly.yml
vendored
Normal file
54
.github/workflows/tag-nightly.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
name: Nightly Tag Push
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs every day at midnight UTC
|
||||
|
||||
permissions:
|
||||
contents: write # Allows pushing tags to the repository
|
||||
|
||||
jobs:
|
||||
create-and-push-tag:
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
|
||||
- name: Check for existing nightly tag
|
||||
id: check_tag
|
||||
run: |
|
||||
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
|
||||
echo "A tag starting with 'nightly-latest' already exists on HEAD."
|
||||
echo "tag_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No tag starting with 'nightly-latest' exists on HEAD."
|
||||
echo "tag_exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# don't tag again if HEAD already has a nightly-latest tag on it
|
||||
- name: Create Nightly Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
env:
|
||||
DATE: ${{ github.run_id }}
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
echo "Creating tag: $TAG_NAME"
|
||||
git tag $TAG_NAME
|
||||
|
||||
- name: Push Tag
|
||||
if: steps.check_tag.outputs.tag_exists == 'false'
|
||||
run: |
|
||||
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
|
||||
git push origin $TAG_NAME
|
||||
|
||||
1
.prettierignore
Normal file
1
.prettierignore
Normal file
@@ -0,0 +1 @@
|
||||
backend/tests/integration/tests/pruning/website
|
||||
@@ -0,0 +1,46 @@
|
||||
"""fix_user__external_user_group_id_fk
|
||||
|
||||
Revision ID: 46b7a812670f
|
||||
Revises: f32615f71aeb
|
||||
Create Date: 2024-09-23 12:58:03.894038
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "46b7a812670f"
|
||||
down_revision = "f32615f71aeb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing primary key
|
||||
op.drop_constraint(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
type_="primary",
|
||||
)
|
||||
|
||||
# Add the new composite primary key
|
||||
op.create_primary_key(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
["user_id", "external_user_group_id", "cc_pair_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the composite primary key
|
||||
op.drop_constraint(
|
||||
"user__external_user_group_id_pkey",
|
||||
"user__external_user_group_id",
|
||||
type_="primary",
|
||||
)
|
||||
# Delete all entries from the table
|
||||
op.execute("DELETE FROM user__external_user_group_id")
|
||||
|
||||
# Recreate the original primary key on user_id
|
||||
op.create_primary_key(
|
||||
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -69,6 +70,30 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def skip_cc_pair_pruning_by_task(
|
||||
pruning_task: TaskQueueState | None, db_session: Session
|
||||
) -> bool:
|
||||
"""task should be the latest prune task for this cc_pair"""
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
# if only one prune is allowed at any time, then check to see if any prune
|
||||
# is active
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
return True
|
||||
|
||||
if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
|
||||
# if the last task is live right now, we shouldn't start a new one
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
@@ -79,31 +104,26 @@ def should_prune_cc_pair(
|
||||
connector_id=connector.id, credential_id=credential.id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
|
||||
if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
|
||||
if not last_pruning_task:
|
||||
# If the connector has never been pruned, then compare vs when the connector
|
||||
# was created
|
||||
time_since_initialization = current_db_time - connector.time_created
|
||||
if time_since_initialization.total_seconds() >= connector.prune_freq:
|
||||
return True
|
||||
return False
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
# if the last prune task hasn't started, we shouldn't start a new one
|
||||
return False
|
||||
|
||||
# if the last prune task has a start time, then compare against it to determine
|
||||
# if we should start
|
||||
time_since_last_pruning = current_db_time - last_pruning_task.start_time
|
||||
return time_since_last_pruning.total_seconds() >= connector.prune_freq
|
||||
|
||||
@@ -141,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
"""Checks to see if we're listening to the named queue"""
|
||||
|
||||
# how to get a list of queues this worker is listening to
|
||||
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
|
||||
queue_names = list(worker.app.amqp.queues.consume_from.keys())
|
||||
for queue_name in queue_names:
|
||||
if queue_name == name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_is_worker_primary(worker: Any) -> bool:
|
||||
"""There are multiple approaches that could be taken, but the way we do it is to
|
||||
check the hostname set for the celery worker, either in celeryconfig.py or on the
|
||||
command line."""
|
||||
hostname = worker.hostname
|
||||
if hostname.startswith("light"):
|
||||
return False
|
||||
|
||||
if hostname.startswith("heavy"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
|
||||
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
@@ -9,6 +11,7 @@ from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
@@ -36,12 +39,30 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_connection_retry_on_startup = True
|
||||
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
||||
|
||||
# redis broker settings
|
||||
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
"retry_on_timeout": True,
|
||||
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
|
||||
"socket_keepalive": True,
|
||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
}
|
||||
|
||||
# redis backend settings
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
|
||||
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
|
||||
redis_socket_keepalive = True
|
||||
redis_retry_on_timeout = True
|
||||
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
|
||||
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.utils.log import get_task_logger
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
|
||||
redis_pool = RedisPool()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def check_for_connector_deletion_task() -> None:
|
||||
r = redis_pool.get_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
last_indexing = get_last_attempt(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if last_indexing:
|
||||
if (
|
||||
last_indexing.status == IndexingStatus.IN_PROGRESS
|
||||
or last_indexing.status == IndexingStatus.NOT_STARTED
|
||||
):
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
140
backend/danswer/background/celery/tasks/periodic/tasks.py
Normal file
140
backend/danswer/background/celery/tasks/periodic/tasks.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
120
backend/danswer/background/celery/tasks/pruning/tasks.py
Normal file
120
backend/danswer/background/celery/tasks/pruning/tasks.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from celery import shared_task
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
@shared_task(
    name="check_for_prune_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
    """Periodic beat task: scan every connector/credential pair and enqueue a
    pruning task for each one that is currently due for pruning."""
    with Session(get_sqlalchemy_engine()) as db_session:
        for cc_pair in get_connector_credential_pairs(db_session):
            # skip pairs that are not due for pruning yet
            if not should_prune_cc_pair(
                connector=cc_pair.connector,
                credential=cc_pair.credential,
                db_session=db_session,
            ):
                continue

            task_logger.info(f"Pruning the {cc_pair.connector.name} connector")

            prune_documents_task.apply_async(
                kwargs={
                    "connector_id": cc_pair.connector.id,
                    "credential_id": cc_pair.credential.id,
                }
            )
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
    """Connector pruning task for a single cc_pair.

    Pulls the full set of document IDs from the source and deletes any locally
    stored document whose ID no longer appears in that set.

    Args:
        connector_id: id of the connector side of the cc_pair to prune
        credential_id: id of the credential side of the cc_pair to prune

    Raises:
        Exception: any failure is logged (with traceback) and re-raised so
            celery records the task as failed.
    """
    with Session(get_sqlalchemy_engine()) as db_session:
        try:
            cc_pair = get_connector_credential_pair(
                db_session=db_session,
                connector_id=connector_id,
                credential_id=credential_id,
            )

            if not cc_pair:
                task_logger.warning(
                    f"ccpair not found for {connector_id} {credential_id}"
                )
                return

            runnable_connector = instantiate_connector(
                db_session,
                cc_pair.connector.source,
                InputType.PRUNE,
                cc_pair.connector.connector_specific_config,
                cc_pair.credential,
            )

            # every document id the source currently knows about
            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
                runnable_connector
            )

            # every document id we have indexed locally for this cc_pair
            all_indexed_document_ids = {
                doc.id
                for doc in get_documents_for_connector_credential_pair(
                    db_session=db_session,
                    connector_id=connector_id,
                    credential_id=credential_id,
                )
            }

            # local docs the source no longer knows about get pruned
            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)

            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            if not doc_ids_to_remove:
                task_logger.info(
                    f"No docs to prune from {cc_pair.connector.source} connector"
                )
                return

            task_logger.info(
                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
            )
            delete_connector_credential_pair_batch(
                document_ids=doc_ids_to_remove,
                connector_id=connector_id,
                credential_id=credential_id,
                document_index=document_index,
            )
        except Exception:
            task_logger.exception(
                f"Failed to run pruning for connector id {connector_id}."
            )
            # bare raise preserves the original traceback; `raise e` is redundant
            raise
|
||||
526
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
526
backend/danswer/background/celery/tasks/vespa/tasks.py
Normal file
@@ -0,0 +1,526 @@
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.utils.log import get_task_logger
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
redis_pool = RedisPool()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
    name="check_for_vespa_sync_task",
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
)
def check_for_vespa_sync_task() -> None:
    """Runs periodically to check if any document needs syncing.
    Generates sets of tasks for Celery if syncing is needed.

    Covers three sources of staleness: stale documents (grouped by cc_pair),
    out-of-date document sets, and out-of-date user groups (enterprise only).
    Exceptions are swallowed after logging so the beat schedule keeps running.
    """

    r = redis_pool.get_client()

    # short-lived lock ensures only one instance of this beat task runs at a time
    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        with Session(get_sqlalchemy_engine()) as db_session:
            try_generate_stale_document_sync_tasks(db_session, r, lock_beat)

            # check if any document sets are not synced
            document_set_info = fetch_document_sets(
                user_id=None, db_session=db_session, include_outdated=True
            )
            for document_set, _ in document_set_info:
                try_generate_document_set_sync_tasks(
                    document_set, db_session, r, lock_beat
                )

            # check if any user groups are not synced
            try:
                # user groups live in the enterprise edition; resolved dynamically
                fetch_user_groups = fetch_versioned_implementation(
                    "danswer.db.user_group", "fetch_user_groups"
                )

                user_groups = fetch_user_groups(
                    db_session=db_session, only_up_to_date=False
                )
                for usergroup in user_groups:
                    try_generate_user_group_sync_tasks(
                        usergroup, db_session, r, lock_beat
                    )
            except ModuleNotFoundError:
                # Always exceptions on the MIT version, which is expected
                pass
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception")
    finally:
        # only release if we actually acquired it above
        if lock_beat.owned():
            lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
    db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    """Queue Vespa sync tasks for all stale documents, grouped by cc_pair.

    Returns the total number of tasks generated, or None if nothing was
    queued (fence already up, or no stale documents found).
    """
    # a fence that is still up means a previous batch is in flight; do nothing
    if r.exists(RedisConnectorCredentialPair.get_fence_key()):
        return None

    # start from a clean monitoring taskset
    r.delete(RedisConnectorCredentialPair.get_taskset_key())

    stale_doc_count = count_documents_by_needs_sync(db_session)
    if stale_doc_count == 0:
        return None

    task_logger.info(
        f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
    )
    task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")

    # all stale docs could technically be synced in one big pass, but grouping
    # by cc_pair keeps the progress reporting understandable
    total_tasks_generated = 0
    for cc_pair in get_connector_credential_pairs(db_session):
        redis_cc_pair = RedisConnectorCredentialPair(cc_pair.id)
        tasks_generated = redis_cc_pair.generate_tasks(
            celery_app, db_session, r, lock_beat
        )

        # skip pairs that produced no work (None or 0)
        if not tasks_generated:
            continue

        task_logger.info(
            f"RedisConnector.generate_tasks finished for single cc_pair. "
            f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
        )
        total_tasks_generated += tasks_generated

    task_logger.info(
        f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
    )

    # raise the fence only after every task has been queued
    r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
    return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
    document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    """Queue sync tasks for a single document set if it is out of date.

    Returns the number of tasks generated, or None if no work was queued.
    """
    # extend the beat lock so it doesn't expire while we work
    lock_beat.reacquire()

    rds = RedisDocumentSet(document_set.id)

    # a fence that is still up means a previous batch is still being processed
    if r.exists(rds.fence_key):
        return None

    # refresh from the db rather than trusting a cached row: a stale value
    # here races with the monitor/cleanup function
    db_session.refresh(document_set)
    if document_set.is_up_to_date:
        return None

    # start the monitoring taskset from scratch
    r.delete(rds.taskset_key)

    task_logger.info(
        f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
    )

    # queue every document that needs updating
    tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # a sync with 0 tasks is deliberately allowed: sets/groups can be created
    # with no entries and still need to be marked as up to date
    task_logger.info(
        f"RedisDocumentSet.generate_tasks finished. "
        f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
    )

    # raise the fence only after every task has been queued
    r.set(rds.fence_key, tasks_generated)
    return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
    usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    """Queue sync tasks for a single user group if it is out of date.

    Returns the number of tasks generated, or None if no work was queued.
    """
    # extend the beat lock so it doesn't expire while we work
    lock_beat.reacquire()

    rug = RedisUserGroup(usergroup.id)

    # a fence that is still up means a previous batch is still being processed
    if r.exists(rug.fence_key):
        return None

    # refresh from the db rather than trusting a cached row: a stale value
    # here races with the monitor/cleanup function
    db_session.refresh(usergroup)
    if usergroup.is_up_to_date:
        return None

    # start the monitoring taskset from scratch
    r.delete(rug.taskset_key)

    task_logger.info(
        f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
    )

    # queue every document that needs updating
    tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # a sync with 0 tasks is deliberately allowed: sets/groups can be created
    # with no entries and still need to be marked as up to date
    task_logger.info(
        f"RedisUserGroup.generate_tasks finished. "
        f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
    )

    # raise the fence only after every task has been queued
    r.set(rug.fence_key, tasks_generated)
    return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
    """Check progress of the stale-document sync taskset and tear down the
    fence and taskset once every queued task has completed."""
    fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
    if fence_value is None:
        # no fence means no batch is in flight
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
    task_logger.info(
        f"Stale document sync progress: remaining={count} initial={initial_count}"
    )
    if count != 0:
        # tasks are still outstanding; check again on the next beat
        return

    # all tasks have reported in; clean up the monitoring state
    r.delete(RedisConnectorCredentialPair.get_taskset_key())
    r.delete(RedisConnectorCredentialPair.get_fence_key())
    task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
    key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    """Check progress of one document set's sync taskset (identified by its
    redis fence key) and finalize it once all tasks have completed.

    Finalization either deletes the document set (no connectors remain) or
    marks it as synced, then tears down the taskset and fence keys.
    """
    fence_key = key_bytes.decode("utf-8")
    document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
    if document_set_id is None:
        # fixed: was a plain string missing the f-prefix and referencing an
        # undefined `key`, so the literal text "{key}" was logged
        task_logger.warning(f"could not parse document set id from {fence_key}")
        return

    rds = RedisDocumentSet(document_set_id)

    fence_value = r.get(rds.fence_key)
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = cast(int, r.scard(rds.taskset_key))
    task_logger.info(
        f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        # tasks still outstanding; check again on the next beat
        return

    document_set = cast(
        DocumentSet,
        get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
    )  # casting since we "know" a document set with this ID exists
    if document_set:
        if not document_set.connector_credential_pairs:
            # if there are no connectors, then delete the document set.
            delete_document_set(document_set_row=document_set, db_session=db_session)
            task_logger.info(
                f"Successfully deleted document set with ID: '{document_set_id}'!"
            )
        else:
            mark_document_set_as_synced(document_set_id, db_session)
            task_logger.info(
                f"Successfully synced document set with ID: '{document_set_id}'!"
            )

    # tear down the monitoring state
    r.delete(rds.taskset_key)
    r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
    """Check progress of one cc_pair's deletion taskset (identified by its
    redis fence key) and, once all document-deletion tasks have completed,
    clean up the related Postgres entities and finally the cc_pair itself.

    On failure the error is recorded on the cc_pair (for display in the UI)
    and re-raised.
    """
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
    if cc_pair_id is None:
        # fixed: message said "document set id" (copy-paste), lacked the
        # f-prefix and referenced an undefined `key`
        task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
        return

    rcd = RedisConnectorDeletion(cc_pair_id)

    fence_value = r.get(rcd.fence_key)
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = cast(int, r.scard(rcd.taskset_key))
    task_logger.info(
        f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        # document deletion tasks still outstanding; check again on the next beat
        return

    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
        if not cc_pair:
            return

        try:
            # clean up the rest of the related Postgres entities
            # index attempts
            delete_index_attempts(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
            )

            # document sets
            delete_document_set_cc_pair_relationship__no_commit(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
                credential_id=cc_pair.credential_id,
            )

            # user groups (enterprise only; no-op on the MIT version)
            cleanup_user_groups = fetch_versioned_implementation_with_fallback(
                "danswer.db.user_group",
                "delete_user_group_cc_pair_relationship__no_commit",
                noop_fallback,
            )
            cleanup_user_groups(
                cc_pair_id=cc_pair.id,
                db_session=db_session,
            )

            # finally, delete the cc-pair
            delete_connector_credential_pair__no_commit(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
                credential_id=cc_pair.credential_id,
            )
            # if there are no credentials left, delete the connector
            connector = fetch_connector_by_id(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
            )
            # fixed: the original `if not connector or not len(connector.credentials)`
            # would call db_session.delete(None) when the connector was missing
            if connector and not connector.credentials:
                task_logger.info(
                    "Found no credentials left for connector, deleting connector"
                )
                db_session.delete(connector)
            db_session.commit()
        except Exception as e:
            stack_trace = traceback.format_exc()
            error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
            # persist the failure so it can be surfaced for this cc_pair
            add_deletion_failure_message(db_session, cc_pair.id, error_message)
            task_logger.exception(
                f"Failed to run connector_deletion. "
                f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
            )
            raise

        task_logger.info(
            f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
            f"and credential_id: '{cc_pair.credential_id}'. "
            f"Deleted {initial_count} docs."
        )

        # tear down the monitoring state
        r.delete(rcd.taskset_key)
        r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
    """This is a celery beat task that monitors and finalizes metadata sync tasksets.
    It scans for fence values and then gets the counts of any associated tasksets.
    If the count is 0, that means all tasks finished and we should clean up.

    This task lock timeout is CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
    do anything too expensive in this function!
    """
    r = redis_pool.get_client()

    # short-lived lock ensures only one instance of this beat task runs at a time
    lock_beat = r.lock(
        DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # prevent overlapping tasks
        if not lock_beat.acquire(blocking=False):
            return

        # stale-document taskset uses a single, fixed fence key
        if r.exists(RedisConnectorCredentialPair.get_fence_key()):
            monitor_connector_taskset(r)

        # per-cc_pair deletion tasksets are discovered by fence key prefix
        for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
            monitor_connector_deletion_taskset(key_bytes, r)

        with Session(get_sqlalchemy_engine()) as db_session:
            for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
                monitor_document_set_taskset(key_bytes, r, db_session)

            for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
                # user groups are enterprise only; resolved dynamically with a no-op fallback
                monitor_usergroup_taskset = (
                    fetch_versioned_implementation_with_fallback(
                        "danswer.background.celery.tasks.vespa.tasks",
                        "monitor_usergroup_taskset",
                        noop_fallback,
                    )
                )
                monitor_usergroup_taskset(key_bytes, r, db_session)

        # uncomment for debugging if needed
        # r_celery = celery_app.broker_connection().channel().client
        # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
        # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    finally:
        # only release if we actually acquired it above
        if lock_beat.owned():
            lock_beat.release()
|
||||
|
||||
|
||||
@shared_task(
    name="vespa_metadata_sync_task",
    bind=True,
    soft_time_limit=45,
    time_limit=60,
    max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
    """Sync one document's metadata (document sets, access, boost, hidden)
    to Vespa, then mark it as synced in the db.

    Returns False if the document no longer exists, True otherwise.
    Unexpected errors trigger a celery retry with exponential backoff.
    """
    task_logger.info(f"document_id={document_id}")

    try:
        with Session(get_sqlalchemy_engine()) as db_session:
            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            doc = get_document(document_id, db_session)
            if not doc:
                return False

            # document set sync
            doc_sets = fetch_document_sets_for_document(document_id, db_session)
            update_doc_sets: set[str] = set(doc_sets)

            # User group sync
            doc_access = get_access_for_document(
                document_id=document_id, db_session=db_session
            )
            update_request = UpdateRequest(
                document_ids=[document_id],
                document_sets=update_doc_sets,
                access=doc_access,
                boost=doc.boost,
                hidden=doc.hidden,
            )

            # update Vespa
            document_index.update(update_requests=[update_request])

            # update db last. Worst case = we crash right before this and
            # the sync might repeat again later
            mark_document_as_synced(document_id, db_session)
    except SoftTimeLimitExceeded:
        # NOTE(review): this falls through to `return True`, reporting success
        # even though the sync did not complete — confirm this is intended
        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
    except Exception as e:
        task_logger.exception("Unexpected exception")

        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
        countdown = 2 ** (self.request.retries + 4)
        # retry raises, so `return True` below is not reached on this path
        self.retry(exc=e, countdown=countdown)

    return True
|
||||
@@ -10,15 +10,27 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.utils.log import get_task_logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import get_document_connector_counts
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
@@ -26,6 +38,9 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@@ -108,3 +123,89 @@ def delete_connector_credential_pair_batch(
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@shared_task(
    name="document_by_cc_pair_cleanup_task",
    bind=True,
    soft_time_limit=45,
    time_limit=60,
    max_retries=3,
)
def document_by_cc_pair_cleanup_task(
    self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
    """Remove one cc_pair's reference to a document.

    Branches on how many cc_pairs still reference the document:
    - exactly 1 (this one): delete the document from Vespa and the db entirely
    - more than 1: drop only this cc_pair's reference and resync the
      document's metadata (access, document sets) to Vespa
    - 0: nothing to do

    Returns False if the document vanished mid-cleanup, True otherwise.
    Unexpected errors trigger a celery retry with exponential backoff.
    """
    task_logger.info(f"document_id={document_id}")

    try:
        with Session(get_sqlalchemy_engine()) as db_session:
            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            count = get_document_connector_count(db_session, document_id)
            if count == 1:
                # count == 1 means this is the only remaining cc_pair reference to the doc
                # delete it from vespa and the db
                document_index.delete_single(doc_id=document_id)
                delete_documents_complete__no_commit(
                    db_session=db_session,
                    document_ids=[document_id],
                )
            elif count > 1:
                # count > 1 means the document still has cc_pair references
                doc = get_document(document_id, db_session)
                if not doc:
                    return False

                # the below functions do not include cc_pairs being deleted.
                # i.e. they will correctly omit access for the current cc_pair
                doc_access = get_access_for_document(
                    document_id=document_id, db_session=db_session
                )

                doc_sets = fetch_document_sets_for_document(document_id, db_session)
                update_doc_sets: set[str] = set(doc_sets)

                update_request = UpdateRequest(
                    document_ids=[document_id],
                    document_sets=update_doc_sets,
                    access=doc_access,
                    boost=doc.boost,
                    hidden=doc.hidden,
                )

                # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
                document_index.update_single(update_request=update_request)

                # there are still other cc_pair references to the doc, so just resync to Vespa
                delete_document_by_connector_credential_pair__no_commit(
                    db_session=db_session,
                    document_id=document_id,
                    connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
                        connector_id=connector_id,
                        credential_id=credential_id,
                    ),
                )

                mark_document_as_synced(document_id, db_session)
            else:
                # count == 0: no references remain; nothing to clean up
                pass

                # update_docs_last_modified__no_commit(
                #     db_session=db_session,
                #     document_ids=[document_id],
                # )

            # single commit covers whichever branch ran above
            db_session.commit()
    except SoftTimeLimitExceeded:
        # NOTE(review): falls through to `return True` even though the cleanup
        # did not complete — confirm this is intended
        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
    except Exception as e:
        task_logger.exception("Unexpected exception")

        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
        countdown = 2 ** (self.request.retries + 4)
        # retry raises, so `return True` below is not reached on this path
        self.retry(exc=e, countdown=countdown)

    return True
|
||||
|
||||
@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -103,15 +104,24 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
search_settings=search_settings,
|
||||
heartbeat=IndexingHeartbeat(
|
||||
index_attempt_id=index_attempt.id,
|
||||
db_session=db_session,
|
||||
# let the world know we're still making progress after
|
||||
# every 10 batches
|
||||
freq=10,
|
||||
),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
ignore_time_skip=(
|
||||
index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -416,6 +416,7 @@ def update_loop(
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -444,6 +445,7 @@ def update_loop(
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@@ -164,13 +164,29 @@ REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
# will propagate to both our redis client as well as celery's redis client
|
||||
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
|
||||
|
||||
# our redis client only, not celery's
|
||||
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
|
||||
# Setting to None may help when there is a proxy in the way closing idle connections
|
||||
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
|
||||
try:
|
||||
CELERY_BROKER_POOL_LIMIT = int(
|
||||
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -247,6 +263,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
# Maximum size for Jira tickets in bytes (default: 100KB)
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -270,7 +290,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import platform
|
||||
import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
@@ -34,7 +36,9 @@ POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
@@ -62,6 +66,7 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
@@ -104,6 +109,7 @@ class DocumentSource(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
@@ -186,6 +192,7 @@ class DanswerCeleryQueues:
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
@@ -198,3 +205,13 @@ class DanswerCeleryPriority(int, Enum):
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
|
||||
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.confluence.confluence_utils import (
|
||||
build_confluence_document_id,
|
||||
)
|
||||
from danswer.connectors.confluence.confluence_utils import get_used_attachments
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
list[str]: List of filename currently in used
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
@@ -624,13 +610,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
# The url and the id are the same
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
)
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
files_in_used = get_used_attachments(page_html)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
@@ -683,8 +672,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
attachment_url = self._attachment_to_download_link(
|
||||
self.confluence_client, attachment
|
||||
# The url and the id are the same
|
||||
attachment_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"]
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
|
||||
@@ -50,6 +50,12 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
if retry_after > 600:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
|
||||
)
|
||||
retry_after = max_delay
|
||||
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=semantic_rep)],
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {self.jira_project}",
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {self.jira_project} AND "
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
@@ -42,6 +42,7 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
|
||||
from danswer.connectors.teams.connector import TeamsConnector
|
||||
from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.wikipedia.connector import WikipediaConnector
|
||||
from danswer.connectors.xenforo.connector import XenforoConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
@@ -62,6 +63,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
@@ -97,6 +99,7 @@ def identify_connector_class(
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -8,13 +8,12 @@ from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
|
||||
# list of messages in a thread
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
|
||||
|
||||
def _make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
||||
"channel"
|
||||
]
|
||||
|
||||
|
||||
def _get_channels(
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
get_private: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
# also get private channels the bot is added to
|
||||
types=["public_channel", "private_channel"]
|
||||
if get_private
|
||||
else ["public_channel"],
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
@@ -88,19 +57,38 @@ def _get_channels(
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
# try getting private channels as well at first
|
||||
try:
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=True
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||
logger.info("trying again without private channels")
|
||||
if get_public:
|
||||
channel_types = ["public_channel"]
|
||||
else:
|
||||
logger.warning("No channels to fetch")
|
||||
return []
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=False
|
||||
)
|
||||
return channels
|
||||
|
||||
|
||||
def get_channel_messages(
|
||||
@@ -112,14 +100,14 @@ def get_channel_messages(
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
_make_slack_api_call(
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
@@ -131,7 +119,7 @@ def get_channel_messages(
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
"""Get all messages in a thread"""
|
||||
threads: list[MessageType] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
@@ -266,7 +254,7 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def get_all_docs(
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
workspace: str,
|
||||
channels: list[str] | None = None,
|
||||
@@ -328,7 +316,44 @@ def get_all_docs(
|
||||
)
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector):
|
||||
def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
|
||||
This makes it an order of magnitude faster than get_all_docs
|
||||
"""
|
||||
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
all_doc_ids = set()
|
||||
for channel in filtered_channels:
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
|
||||
|
||||
return all_doc_ids
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
@@ -349,6 +374,16 @@ class SlackPollConnector(PollConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
return _get_all_doc_ids(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -356,7 +391,7 @@ class SlackPollConnector(PollConnector):
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
documents: list[Document] = []
|
||||
for document in get_all_docs(
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
workspace=self.workspace,
|
||||
channels=self.channels,
|
||||
|
||||
@@ -10,11 +10,13 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
# number of messages we request per page when fetching paginated slack messages
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
@@ -34,7 +36,7 @@ def get_message_link(
|
||||
)
|
||||
|
||||
|
||||
def make_slack_api_call_logged(
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
|
||||
return logged_call
|
||||
|
||||
|
||||
def make_slack_api_call_paginated(
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
|
||||
return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
user_id: str | None,
|
||||
client: WebClient,
|
||||
|
||||
0
backend/danswer/connectors/xenforo/__init__.py
Normal file
0
backend/danswer/connectors/xenforo/__init__.py
Normal file
244
backend/danswer/connectors/xenforo/connector.py
Normal file
244
backend/danswer/connectors/xenforo/connector.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum.
|
||||
|
||||
To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance
|
||||
of the class. The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the
|
||||
forum, followed by the board name. For example:
|
||||
|
||||
base_url = 'https://www.example.com/forum/boards/some-topic/'
|
||||
|
||||
The `load_from_state` method is used to load documents from the forum. It takes an optional `state` parameter, which
|
||||
can be used to specify a state from which to start loading documents.
|
||||
"""
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytz
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import Tag
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_title(soup: BeautifulSoup) -> str:
|
||||
el = soup.find("h1", "p-title-value")
|
||||
if not el:
|
||||
return ""
|
||||
title = el.text
|
||||
for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"):
|
||||
title = title.replace(char, "_")
|
||||
return title
|
||||
|
||||
|
||||
def get_pages(soup: BeautifulSoup, url: str) -> list[str]:
|
||||
page_tags = soup.select("li.pageNav-page")
|
||||
page_numbers = []
|
||||
for button in page_tags:
|
||||
if re.match(r"^\d+$", button.text):
|
||||
page_numbers.append(button.text)
|
||||
|
||||
max_pages = int(max(page_numbers, key=int)) if page_numbers else 1
|
||||
|
||||
all_pages = []
|
||||
for x in range(1, int(max_pages) + 1):
|
||||
all_pages.append(f"{url}page-{x}")
|
||||
return all_pages
|
||||
|
||||
|
||||
def parse_post_date(post_element: BeautifulSoup) -> datetime:
|
||||
el = post_element.find("time")
|
||||
if not isinstance(el, Tag) or "datetime" not in el.attrs:
|
||||
return datetime.utcfromtimestamp(0).replace(tzinfo=timezone.utc)
|
||||
|
||||
date_value = el["datetime"]
|
||||
|
||||
# Ensure date_value is a string (if it's a list, take the first element)
|
||||
if isinstance(date_value, list):
|
||||
date_value = date_value[0]
|
||||
|
||||
post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z")
|
||||
return datetime_to_utc(post_date)
|
||||
|
||||
|
||||
def scrape_page_posts(
|
||||
soup: BeautifulSoup,
|
||||
page_index: int,
|
||||
url: str,
|
||||
initial_run: bool,
|
||||
start_time: datetime,
|
||||
) -> list:
|
||||
title = get_title(soup)
|
||||
|
||||
documents = []
|
||||
for post in soup.find_all("div", class_="message-inner"):
|
||||
post_date = parse_post_date(post)
|
||||
if initial_run or post_date > start_time:
|
||||
el = post.find("div", class_="bbWrapper")
|
||||
if not el:
|
||||
continue
|
||||
post_text = el.get_text(strip=True) + "\n"
|
||||
author_tag = post.find("a", class_="username")
|
||||
if author_tag is None:
|
||||
author_tag = post.find("span", class_="username")
|
||||
author = author_tag.get_text(strip=True) if author_tag else "Deleted author"
|
||||
formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# TODO: if a caller calls this for each page of a thread, it may see the
|
||||
# same post multiple times if there is a sticky post
|
||||
# that appears on each page of a thread.
|
||||
# it's important to generate unique doc id's, so page index is part of the
|
||||
# id. We may want to de-dupe this stuff inside the indexing service.
|
||||
document = Document(
|
||||
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
|
||||
sections=[Section(link=url, text=post_text)],
|
||||
title=title,
|
||||
source=DocumentSource.XENFORO,
|
||||
semantic_identifier=title,
|
||||
primary_owners=[BasicExpertInfo(display_name=author)],
|
||||
metadata={
|
||||
"type": "post",
|
||||
"author": author,
|
||||
"time": formatted_time,
|
||||
},
|
||||
doc_updated_at=post_date,
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
|
||||
class XenforoConnector(LoadConnector):
|
||||
# Class variable to track if the connector has been run before
|
||||
has_been_run_before = False
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
self.initial_run = not XenforoConnector.has_been_run_before
|
||||
self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1)
|
||||
self.cookies: dict[str, str] = {}
|
||||
# mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/)
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/121.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if credentials:
|
||||
logger.warning("Unexpected credentials provided for Xenforo Connector")
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
# Standardize URL to always end in /.
|
||||
if self.base_url[-1] != "/":
|
||||
self.base_url += "/"
|
||||
|
||||
# Remove all extra parameters from the end such as page, post.
|
||||
matches = ("threads/", "boards/", "forums/")
|
||||
for each in matches:
|
||||
if each in self.base_url:
|
||||
try:
|
||||
self.base_url = self.base_url[
|
||||
0 : self.base_url.index(
|
||||
"/", self.base_url.index(each) + len(each)
|
||||
)
|
||||
+ 1
|
||||
]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
all_threads = []
|
||||
|
||||
# If the URL contains "boards/" or "forums/", find all threads.
|
||||
if "boards/" in self.base_url or "forums/" in self.base_url:
|
||||
pages = get_pages(self.requestsite(self.base_url), self.base_url)
|
||||
|
||||
# Get all pages on thread_list_page
|
||||
for pre_count, thread_list_page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r"
|
||||
)
|
||||
all_threads += self.get_threads(thread_list_page)
|
||||
# If the URL contains "threads/", add the thread to the list.
|
||||
elif "threads/" in self.base_url:
|
||||
all_threads.append(self.base_url)
|
||||
|
||||
# Process all threads
|
||||
for thread_count, thread_url in enumerate(all_threads, start=1):
|
||||
soup = self.requestsite(thread_url)
|
||||
if soup is None:
|
||||
logger.error(f"Failed to load page: {self.base_url}")
|
||||
continue
|
||||
pages = get_pages(soup, thread_url)
|
||||
# Getting all pages for all threads
|
||||
for page_index, page in enumerate(pages, start=1):
|
||||
logger.info(
|
||||
f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r"
|
||||
)
|
||||
soup_page = self.requestsite(page)
|
||||
doc_batch.extend(
|
||||
scrape_page_posts(
|
||||
soup_page, page_index, thread_url, self.initial_run, self.start
|
||||
)
|
||||
)
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
# Mark the initial run finished after all threads and pages have been processed
|
||||
XenforoConnector.has_been_run_before = True
|
||||
|
||||
def get_threads(self, url: str) -> list[str]:
|
||||
soup = self.requestsite(url)
|
||||
thread_tags = soup.find_all(class_="structItem-title")
|
||||
base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url))
|
||||
threads = []
|
||||
for x in thread_tags:
|
||||
y = x.find_all(href=True)
|
||||
for element in y:
|
||||
link = element["href"]
|
||||
if "threads/" in link:
|
||||
stripped = link[0 : link.rfind("/") + 1]
|
||||
if base_url + stripped not in threads:
|
||||
threads.append(base_url + stripped)
|
||||
return threads
|
||||
|
||||
def requestsite(self, url: str) -> BeautifulSoup:
|
||||
try:
|
||||
response = requests.get(
|
||||
url, cookies=self.cookies, headers=self.headers, timeout=10
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"<{url}> Request Error: {response.status_code} - {response.reason}"
|
||||
)
|
||||
return BeautifulSoup(response.text, "html.parser")
|
||||
except TimeoutError:
|
||||
logger.error("Timed out Error.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error on {url}")
|
||||
logger.exception(e)
|
||||
return BeautifulSoup("", "html.parser")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = XenforoConnector(
|
||||
# base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/"
|
||||
base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/"
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
check_if_valid_sync_source,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_document_ids_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> list[str]:
|
||||
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
return list(db_session.execute(doc_ids_stmt).scalars().all())
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
|
||||
|
||||
|
||||
def get_documents_by_ids(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> list[DbDocument]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
documents = db_session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
from sqlalchemy import event
|
||||
@@ -32,14 +34,9 @@ logger = setup_logger()
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
POSTGRES_APP_NAME = (
|
||||
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
|
||||
)
|
||||
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
_SYNC_ENGINE: Engine | None = None
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
@@ -108,6 +105,67 @@ def get_db_current_time(db_session: Session) -> datetime:
|
||||
return result
|
||||
|
||||
|
||||
class SqlEngine:
|
||||
"""Class to manage a global sql alchemy engine (needed for proper resource control)
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
# Default parameters for engine creation
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 40,
|
||||
"max_overflow": 10,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
"""Private helper method to create and return an Engine."""
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
|
||||
)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
return create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different celery workers and tasks
|
||||
need different settings."""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the sql alchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine()
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
"""Class method to set the app name."""
|
||||
cls._app_name = app_name
|
||||
|
||||
@classmethod
|
||||
def get_app_name(cls) -> str:
|
||||
"""Class method to get current app name."""
|
||||
if not cls._app_name:
|
||||
return ""
|
||||
return cls._app_name
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
@@ -125,24 +183,11 @@ def build_connection_string(
|
||||
|
||||
|
||||
def init_sqlalchemy_engine(app_name: str) -> None:
|
||||
global POSTGRES_APP_NAME
|
||||
POSTGRES_APP_NAME = app_name
|
||||
SqlEngine.set_app_name(app_name)
|
||||
|
||||
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
global _SYNC_ENGINE
|
||||
if _SYNC_ENGINE is None:
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
|
||||
)
|
||||
_SYNC_ENGINE = create_engine(
|
||||
connection_string,
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
return _SYNC_ENGINE
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
@@ -154,7 +199,9 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
connect_args={
|
||||
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
|
||||
"server_settings": {
|
||||
"application_name": SqlEngine.get_app_name() + "_async"
|
||||
}
|
||||
},
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
|
||||
@@ -64,19 +64,12 @@ def upsert_cloud_embedding_provider(
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
is_creation: bool = True,
|
||||
) -> FullLLMProvider:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider and is_creation:
|
||||
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
|
||||
|
||||
if not existing_llm_provider:
|
||||
if not is_creation:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider.name} does not exist"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
|
||||
@@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
|
||||
cc_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
@@ -107,12 +108,14 @@ def create_or_add_document_tag_list(
|
||||
return all_tags
|
||||
|
||||
|
||||
def get_tags_by_value_prefix_for_source_types(
|
||||
def find_tags(
|
||||
tag_key_prefix: str | None,
|
||||
tag_value_prefix: str | None,
|
||||
sources: list[DocumentSource] | None,
|
||||
limit: int | None,
|
||||
db_session: Session,
|
||||
# if set, both tag_key_prefix and tag_value_prefix must be a match
|
||||
require_both_to_match: bool = False,
|
||||
) -> list[Tag]:
|
||||
query = select(Tag)
|
||||
|
||||
@@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types(
|
||||
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
|
||||
if tag_value_prefix:
|
||||
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
|
||||
query = query.where(or_(*conditions))
|
||||
|
||||
final_prefix_condition = (
|
||||
and_(*conditions) if require_both_to_match else or_(*conditions)
|
||||
)
|
||||
query = query.where(final_prefix_condition)
|
||||
|
||||
if sources:
|
||||
query = query.where(Tag.source.in_(sources))
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
|
||||
@@ -13,3 +16,14 @@ def get_default_document_index(
|
||||
return VespaIndex(
|
||||
index_name=primary_index_name, secondary_index_name=secondary_index_name
|
||||
)
|
||||
|
||||
|
||||
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
|
||||
"""
|
||||
TODO: Use redis to cache this or something
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
|
||||
@@ -156,6 +156,16 @@ class Deletable(abc.ABC):
|
||||
Class must implement the ability to delete document by their unique document ids.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_single(self, doc_id: str) -> None:
|
||||
"""
|
||||
Given a single document id, hard delete it from the document index
|
||||
|
||||
Parameters:
|
||||
- doc_id: document id as specified by the connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
"""
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import cast
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -479,6 +480,66 @@ class VespaIndex(DocumentIndex):
|
||||
document_ids=doc_ids, index_name=index_name, http_client=http_client
|
||||
)
|
||||
|
||||
def delete_single(self, doc_id: str) -> None:
|
||||
"""Possibly faster overall than the delete method due to using a single
|
||||
delete call with a selection query."""
|
||||
|
||||
# Vespa deletion is poorly documented ... luckily we found this
|
||||
# https://docs.vespa.ai/en/operations/batch-delete.html#example
|
||||
|
||||
doc_id = replace_invalid_doc_id_characters(doc_id)
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with httpx.Client(http2=True) as http_client:
|
||||
for index_name in index_names:
|
||||
params = httpx.QueryParams(
|
||||
{
|
||||
"selection": f"{index_name}.document_id=='{doc_id}'",
|
||||
"cluster": DOCUMENT_INDEX_NAME,
|
||||
}
|
||||
)
|
||||
|
||||
total_chunks_deleted = 0
|
||||
while True:
|
||||
try:
|
||||
resp = http_client.delete(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"Failed to delete chunk, details: {e.response.text}"
|
||||
)
|
||||
raise
|
||||
|
||||
resp_data = resp.json()
|
||||
|
||||
if "documentCount" in resp_data:
|
||||
chunks_deleted = resp_data["documentCount"]
|
||||
total_chunks_deleted += chunks_deleted
|
||||
|
||||
# Check for continuation token to handle pagination
|
||||
if "continuation" not in resp_data:
|
||||
break # Exit loop if no continuation token
|
||||
|
||||
if not resp_data["continuation"]:
|
||||
break # Exit loop if continuation token is empty
|
||||
|
||||
params = params.set("continuation", resp_data["continuation"])
|
||||
|
||||
logger.debug(
|
||||
f"VespaIndex.delete_single: "
|
||||
f"index={index_name} "
|
||||
f"doc={doc_id} "
|
||||
f"chunks_deleted={total_chunks_deleted}"
|
||||
)
|
||||
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.indexing.indexing_heartbeat import Heartbeat
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -123,6 +124,7 @@ class Chunker:
|
||||
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
mini_chunk_size: int = MINI_CHUNK_SIZE,
|
||||
heartbeat: Heartbeat | None = None,
|
||||
) -> None:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
@@ -131,6 +133,7 @@ class Chunker:
|
||||
self.enable_multipass = enable_multipass
|
||||
self.enable_large_chunks = enable_large_chunks
|
||||
self.tokenizer = tokenizer
|
||||
self.heartbeat = heartbeat
|
||||
|
||||
self.blurb_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
@@ -255,7 +258,7 @@ class Chunker:
|
||||
# If the chunk does not have any useable content, it will not be indexed
|
||||
return chunks
|
||||
|
||||
def chunk(self, document: Document) -> list[DocAwareChunk]:
|
||||
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
|
||||
# Specifically for reproducing an issue with gmail
|
||||
if document.source == DocumentSource.GMAIL:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
@@ -302,3 +305,13 @@ class Chunker:
|
||||
normal_chunks.extend(large_chunks)
|
||||
|
||||
return normal_chunks
|
||||
|
||||
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
|
||||
final_chunks: list[DocAwareChunk] = []
|
||||
for document in documents:
|
||||
final_chunks.extend(self._handle_single_document(document))
|
||||
|
||||
if self.heartbeat:
|
||||
self.heartbeat.heartbeat()
|
||||
|
||||
return final_chunks
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.indexing.indexing_heartbeat import Heartbeat
|
||||
from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
@@ -24,6 +20,9 @@ logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingEmbedder(ABC):
|
||||
"""Converts chunks into chunks with embeddings. Note that one chunk may have
|
||||
multiple embeddings associated with it."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
|
||||
provider_type: EmbeddingProvider | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
heartbeat: Heartbeat | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.normalize = normalize
|
||||
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
retrim_content=True,
|
||||
heartbeat=heartbeat,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type: EmbeddingProvider | None = None,
|
||||
api_key: str | None = None,
|
||||
api_url: str | None = None,
|
||||
heartbeat: Heartbeat | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
model_name,
|
||||
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type,
|
||||
api_key,
|
||||
api_url,
|
||||
heartbeat,
|
||||
)
|
||||
|
||||
@log_function_time()
|
||||
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
title_embed_dict[title] = title_embedding
|
||||
|
||||
new_embedded_chunk = IndexChunk(
|
||||
**chunk.dict(),
|
||||
**chunk.model_dump(),
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=chunk_embeddings[0],
|
||||
mini_chunk_embeddings=chunk_embeddings[1:],
|
||||
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
|
||||
@classmethod
|
||||
def from_db_search_settings(
|
||||
cls, search_settings: SearchSettings
|
||||
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
|
||||
) -> "DefaultIndexingEmbedder":
|
||||
return cls(
|
||||
model_name=search_settings.model_name,
|
||||
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
heartbeat=heartbeat,
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model_from_search_settings(
|
||||
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
|
||||
) -> IndexingEmbedder:
|
||||
search_settings: SearchSettings | None
|
||||
if index_model_status == IndexModelStatus.PRESENT:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
elif index_model_status == IndexModelStatus.FUTURE:
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
if not search_settings:
|
||||
raise RuntimeError("No secondary index configured")
|
||||
else:
|
||||
raise RuntimeError("Not supporting embedding model rollbacks")
|
||||
|
||||
return DefaultIndexingEmbedder(
|
||||
model_name=search_settings.model_name,
|
||||
normalize=search_settings.normalize,
|
||||
query_prefix=search_settings.query_prefix,
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
)
|
||||
|
||||
41
backend/danswer/indexing/indexing_heartbeat.py
Normal file
41
backend/danswer/indexing/indexing_heartbeat.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Heartbeat(abc.ABC):
|
||||
"""Useful for any long-running work that goes through a bunch of items
|
||||
and needs to occasionally give updates on progress.
|
||||
e.g. chunking, embedding, updating vespa, etc."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def heartbeat(self, metadata: Any = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IndexingHeartbeat(Heartbeat):
|
||||
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
|
||||
self.cnt = 0
|
||||
|
||||
self.index_attempt_id = index_attempt_id
|
||||
self.db_session = db_session
|
||||
self.freq = freq
|
||||
|
||||
def heartbeat(self, metadata: Any = None) -> None:
|
||||
self.cnt += 1
|
||||
if self.cnt % self.freq == 0:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=self.db_session, index_attempt_id=self.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
index_attempt.time_updated = func.now()
|
||||
self.db_session.commit()
|
||||
else:
|
||||
logger.error("Index attempt not found, this should not happen!")
|
||||
@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import DocumentMetadata
|
||||
from danswer.indexing.chunker import Chunker
|
||||
from danswer.indexing.embedder import IndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -220,8 +221,8 @@ def index_doc_batch_prepare(
|
||||
|
||||
document_ids = [document.id for document in documents]
|
||||
db_docs: list[DBDocument] = get_documents_by_ids(
|
||||
document_ids=document_ids,
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Skip indexing docs that don't have a newer updated at
|
||||
@@ -283,18 +284,10 @@ def index_doc_batch(
|
||||
return 0, 0
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
chunks: list[DocAwareChunk] = []
|
||||
for document in ctx.updatable_docs:
|
||||
chunks.extend(chunker.chunk(document=document))
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings = (
|
||||
embedder.embed_chunks(
|
||||
chunks=chunks,
|
||||
)
|
||||
if chunks
|
||||
else []
|
||||
)
|
||||
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
|
||||
|
||||
updatable_ids = [doc.id for doc in ctx.updatable_docs]
|
||||
|
||||
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=multipass,
|
||||
enable_large_chunks=enable_large_chunks,
|
||||
# after every doc, update status in case there are a bunch of
|
||||
# really long docs
|
||||
heartbeat=IndexingHeartbeat(
|
||||
index_attempt_id=attempt_id, db_session=db_session, freq=1
|
||||
)
|
||||
if attempt_id
|
||||
else None,
|
||||
)
|
||||
|
||||
return partial(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
@@ -315,7 +316,9 @@ class Answer:
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
# as of now, we don't support multiple tool calls in sequence, which is why
|
||||
# we don't need to pass this in here
|
||||
# tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
return
|
||||
@@ -554,8 +557,7 @@ class Answer:
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
for item in itertools.chain([message], stream):
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
|
||||
@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
|
||||
)
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.indexing_heartbeat import Heartbeat
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -95,6 +96,7 @@ class EmbeddingModel:
|
||||
api_url: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
retrim_content: bool = False,
|
||||
heartbeat: Heartbeat | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@@ -107,6 +109,7 @@ class EmbeddingModel:
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_name=model_name, provider_type=provider_type
|
||||
)
|
||||
self.heartbeat = heartbeat
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
@@ -166,6 +169,9 @@ class EmbeddingModel:
|
||||
|
||||
response = self._make_model_server_request(embed_request)
|
||||
embeddings.extend(response.embeddings)
|
||||
|
||||
if self.heartbeat:
|
||||
self.heartbeat.heartbeat()
|
||||
return embeddings
|
||||
|
||||
def encode(
|
||||
|
||||
@@ -3,23 +3,23 @@ from typing import Optional
|
||||
|
||||
import redis
|
||||
from redis.client import Redis
|
||||
from redis.connection import ConnectionPool
|
||||
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER
|
||||
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
|
||||
REDIS_POOL_MAX_CONNECTIONS = 10
|
||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
class RedisPool:
|
||||
_instance: Optional["RedisPool"] = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_pool: ConnectionPool
|
||||
_pool: redis.BlockingConnectionPool
|
||||
|
||||
def __new__(cls) -> "RedisPool":
|
||||
if not cls._instance:
|
||||
@@ -42,30 +42,42 @@ class RedisPool:
|
||||
db: int = REDIS_DB_NUMBER,
|
||||
password: str = REDIS_PASSWORD,
|
||||
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
|
||||
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
|
||||
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
|
||||
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
|
||||
ssl: bool = False,
|
||||
) -> redis.ConnectionPool:
|
||||
) -> redis.BlockingConnectionPool:
|
||||
"""We use BlockingConnectionPool because it will block and wait for a connection
|
||||
rather than error if max_connections is reached. This is far more deterministic
|
||||
behavior and aligned with how we want to use Redis."""
|
||||
|
||||
# Using ConnectionPool is not well documented.
|
||||
# Useful examples: https://github.com/redis/redis-py/issues/780
|
||||
if ssl:
|
||||
return redis.ConnectionPool(
|
||||
return redis.BlockingConnectionPool(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
max_connections=max_connections,
|
||||
timeout=None,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
connection_class=redis.SSLConnection,
|
||||
ssl_ca_certs=ssl_ca_certs,
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
)
|
||||
|
||||
return redis.ConnectionPool(
|
||||
return redis.BlockingConnectionPool(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
max_connections=max_connections,
|
||||
timeout=None,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@@ -10,6 +11,8 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||
@@ -26,7 +29,9 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
|
||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from danswer.db.models import User
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import CCPairPruningTask
|
||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
@@ -36,7 +41,6 @@ from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@@ -190,6 +194,92 @@ def update_cc_pair_name(
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/prune")
|
||||
def get_cc_pair_latest_prune(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CCPairPruningTask:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
# look up the last prune task for this connector (if it exists)
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
if not last_pruning_task:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="No pruning task found.",
|
||||
)
|
||||
|
||||
return CCPairPruningTask(
|
||||
id=last_pruning_task.task_id,
|
||||
name=last_pruning_task.task_name,
|
||||
status=last_pruning_task.status,
|
||||
start_time=last_pruning_task.start_time,
|
||||
register_time=last_pruning_task.register_time,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
|
||||
def prune_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[list[int]]:
|
||||
# avoiding circular refs
|
||||
from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
if skip_cc_pair_pruning_by_task(
|
||||
last_pruning_task,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Pruning task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(f"Pruning the {cc_pair.connector.name} connector.")
|
||||
prune_documents_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the pruning task.",
|
||||
)
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
|
||||
@@ -268,6 +268,14 @@ class CCPairFullInfo(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CCPairPruningTask(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
status: TaskStatus
|
||||
start_time: datetime | None
|
||||
register_time: datetime | None
|
||||
|
||||
|
||||
class FailedConnectorIndexingStatus(BaseModel):
|
||||
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
from danswer.background.celery.celery_app import (
|
||||
check_for_connector_deletion_task,
|
||||
)
|
||||
|
||||
connector_id = connector_credential_pair_identifier.connector_id
|
||||
credential_id = connector_credential_pair_identifier.credential_id
|
||||
|
||||
@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
|
||||
status=ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
|
||||
# run the beat task to pick up this deletion early
|
||||
check_for_connector_deletion_task.apply_async(
|
||||
db_session.commit()
|
||||
|
||||
# run the beat task to pick up this deletion from the db immediately
|
||||
celery_app.send_task(
|
||||
"check_for_connector_deletion_task",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.llm import fetch_provider
|
||||
from danswer.db.llm import remove_llm_provider
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
@@ -124,17 +125,26 @@ def list_llm_providers(
|
||||
def put_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
is_creation: bool = Query(
|
||||
True,
|
||||
False,
|
||||
description="True if updating an existing provider, False if creating a new one",
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProvider:
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_provider(db_session, llm_provider.name)
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider.name} already exists",
|
||||
)
|
||||
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
llm_provider=llm_provider,
|
||||
db_session=db_session,
|
||||
is_creation=is_creation,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
|
||||
@@ -18,7 +18,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
|
||||
from danswer.db.tag import find_tags
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.one_shot_answer.answer_question import stream_search_answer
|
||||
@@ -99,12 +99,25 @@ def get_tags(
|
||||
if not allow_prefix:
|
||||
raise NotImplementedError("Cannot disable prefix match for now")
|
||||
|
||||
db_tags = get_tags_by_value_prefix_for_source_types(
|
||||
tag_key_prefix=match_pattern,
|
||||
tag_value_prefix=match_pattern,
|
||||
key_prefix = match_pattern
|
||||
value_prefix = match_pattern
|
||||
require_both_to_match = False
|
||||
|
||||
# split on = to allow the user to type in "author=bob"
|
||||
EQUAL_PAT = "="
|
||||
if match_pattern and EQUAL_PAT in match_pattern:
|
||||
split_pattern = match_pattern.split(EQUAL_PAT)
|
||||
key_prefix = split_pattern[0]
|
||||
value_prefix = EQUAL_PAT.join(split_pattern[1:])
|
||||
require_both_to_match = True
|
||||
|
||||
db_tags = find_tags(
|
||||
tag_key_prefix=key_prefix,
|
||||
tag_value_prefix=value_prefix,
|
||||
sources=sources,
|
||||
limit=limit,
|
||||
db_session=db_session,
|
||||
require_both_to_match=require_both_to_match,
|
||||
)
|
||||
server_tags = [
|
||||
SourceTag(
|
||||
|
||||
@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_doc_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_group_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_permission_sync_entrypoint,
|
||||
run_external_doc_permission_sync,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_group_permission_sync,
|
||||
)
|
||||
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
|
||||
|
||||
@@ -26,11 +39,18 @@ logger = setup_logger()
|
||||
global_version.set_ee()
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_permissions_task)
|
||||
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_permissions_task(cc_pair_id: int) -> None:
|
||||
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_sync_external_permissions_task",
|
||||
name="check_sync_external_doc_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_permissions_task() -> None:
|
||||
def check_sync_external_doc_permissions_task() -> None:
|
||||
"""Runs periodically to sync external permissions"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_permissions_check(
|
||||
if should_perform_external_doc_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_permissions_task.apply_async(
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_sync_external_group_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_group_permissions_task() -> None:
|
||||
"""Runs periodically to sync external group permissions"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_group_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_group_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id),
|
||||
)
|
||||
|
||||
@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"sync-external-permissions": {
|
||||
"task": "check_sync_external_permissions_task",
|
||||
"schedule": timedelta(seconds=60), # TODO: optimize this
|
||||
"sync-external-doc-permissions": {
|
||||
"task": "check_sync_external_doc_permissions_task",
|
||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||
},
|
||||
"sync-external-group-permissions": {
|
||||
"task": "check_sync_external_group_permissions_task",
|
||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||
},
|
||||
"autogenerate_usage_report": {
|
||||
"task": "autogenerate_usage_report_task",
|
||||
|
||||
52
backend/ee/danswer/background/celery/tasks/vespa/tasks.py
Normal file
52
backend/ee/danswer/background/celery/tasks/vespa/tasks.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import cast
|
||||
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import delete_user_group
|
||||
from ee.danswer.db.user_group import fetch_user_group
|
||||
from ee.danswer.db.user_group import mark_user_group_as_synced
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
key = key_bytes.decode("utf-8")
|
||||
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
|
||||
if not usergroup_id:
|
||||
task_logger.warning("Could not parse usergroup id from {key}")
|
||||
return
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
fence_value = r.get(rug.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rug.taskset_key))
|
||||
task_logger.info(
|
||||
f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
|
||||
r.delete(rug.taskset_key)
|
||||
r.delete(rug.fence_key)
|
||||
@@ -1,21 +1,17 @@
|
||||
from typing import cast
|
||||
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
|
||||
from ee.danswer.db.user_group import delete_user_group
|
||||
from ee.danswer.db.user_group import fetch_user_group
|
||||
from ee.danswer.db.user_group import mark_user_group_as_synced
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -38,13 +34,13 @@ def should_perform_chat_ttl_check(
|
||||
return True
|
||||
|
||||
|
||||
def should_perform_external_permissions_check(
|
||||
def should_perform_external_doc_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
|
||||
task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
@@ -57,41 +53,20 @@ def should_perform_external_permissions_check(
|
||||
return True
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
key = key_bytes.decode("utf-8")
|
||||
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
|
||||
if not usergroup_id:
|
||||
task_logger.warning("Could not parse usergroup id from {key}")
|
||||
return
|
||||
def should_perform_external_group_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
fence_value = r.get(rug.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
count = cast(int, r.scard(rug.taskset_key))
|
||||
task_logger.info(
|
||||
f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
|
||||
r.delete(rug.taskset_key)
|
||||
r.delete(rug.fence_key)
|
||||
return True
|
||||
|
||||
@@ -2,5 +2,9 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def name_sync_external_permissions_task(cc_pair_id: int) -> str:
|
||||
return f"sync_external_permissions_task__{cc_pair_id}"
|
||||
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
|
||||
|
||||
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
|
||||
return f"sync_external_group_permissions_task__{cc_pair_id}"
|
||||
|
||||
@@ -16,7 +16,9 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential__UserGroup
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import DocumentSet__UserGroup
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import Persona__UserGroup
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
@@ -32,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
user_ids: list[UUID] | None = None,
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
where_clause = User__UserGroup.user_group_id == user_group_id
|
||||
if user_ids:
|
||||
where_clause &= User__UserGroup.user_id.in_(user_ids)
|
||||
|
||||
user__user_group_relationships = db_session.scalars(
|
||||
select(User__UserGroup).where(where_clause)
|
||||
).all()
|
||||
for user__user_group_relationship in user__user_group_relationships:
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_credential__user_group_relationships__no_commit(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(Credential__UserGroup).filter(
|
||||
Credential__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _cleanup_persona__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
token_rate_limit__user_group_relationships = db_session.scalars(
|
||||
select(TokenRateLimit__UserGroup).where(
|
||||
TokenRateLimit__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
).all()
|
||||
for (
|
||||
token_rate_limit__user_group_relationship
|
||||
) in token_rate_limit__user_group_relationships:
|
||||
db_session.delete(token_rate_limit__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, outdated_only: bool
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
stmt = select(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
|
||||
)
|
||||
if outdated_only:
|
||||
stmt = stmt.where(
|
||||
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
|
||||
)
|
||||
user_group__cc_pair_relationships = db_session.scalars(stmt)
|
||||
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
|
||||
db_session.delete(user_group__cc_pair_relationship)
|
||||
|
||||
|
||||
def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.execute(
|
||||
delete(DocumentSet__UserGroup).where(
|
||||
DocumentSet__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def validate_user_creation_permissions(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
@@ -62,8 +151,12 @@ def validate_user_creation_permissions(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user_curated_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session, user_id=user.id, only_curator_groups=True
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
# Global curators can curate all groups they are member of
|
||||
only_curator_groups=user.role != UserRole.GLOBAL_CURATOR,
|
||||
)
|
||||
user_curated_group_ids = set([group.id for group in user_curated_groups])
|
||||
target_group_ids_set = set(target_group_ids)
|
||||
@@ -285,42 +378,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
return db_user_group
|
||||
|
||||
|
||||
def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
user_ids: list[UUID] | None = None,
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
where_clause = User__UserGroup.user_group_id == user_group_id
|
||||
if user_ids:
|
||||
where_clause &= User__UserGroup.user_id.in_(user_ids)
|
||||
|
||||
user__user_group_relationships = db_session.scalars(
|
||||
select(User__UserGroup).where(where_clause)
|
||||
).all()
|
||||
for user__user_group_relationship in user__user_group_relationships:
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_credential__user_group_relationships__no_commit(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(Credential__UserGroup).filter(
|
||||
Credential__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
@@ -475,21 +532,6 @@ def update_user_group(
|
||||
return db_user_group
|
||||
|
||||
|
||||
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
token_rate_limit__user_group_relationships = db_session.scalars(
|
||||
select(TokenRateLimit__UserGroup).where(
|
||||
TokenRateLimit__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
).all()
|
||||
for (
|
||||
token_rate_limit__user_group_relationship
|
||||
) in token_rate_limit__user_group_relationships:
|
||||
db_session.delete(token_rate_limit__user_group_relationship)
|
||||
|
||||
|
||||
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
@@ -498,16 +540,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
_cleanup_credential__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
_cleanup_token_rate_limit__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_cleanup_token_rate_limit__user_group_relationships__no_commit(
|
||||
_cleanup_document_set__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_cleanup_persona__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
outdated_only=False,
|
||||
)
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
@@ -516,20 +573,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, outdated_only: bool
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
stmt = select(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
|
||||
)
|
||||
if outdated_only:
|
||||
stmt = stmt.where(
|
||||
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
|
||||
)
|
||||
user_group__cc_pair_relationships = db_session.scalars(stmt)
|
||||
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
|
||||
db_session.delete(user_group__cc_pair_relationship)
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
"""
|
||||
This assumes that all the fk cleanup has already been done.
|
||||
"""
|
||||
db_session.delete(user_group)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
|
||||
@@ -541,26 +590,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group.id,
|
||||
outdated_only=False,
|
||||
)
|
||||
|
||||
# need to flush so that we don't get a foreign key error when deleting the user group row
|
||||
db_session.flush()
|
||||
|
||||
db_session.delete(user_group)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_user_group_cc_pair_relationship__no_commit(
|
||||
cc_pair_id: int, db_session: Session
|
||||
) -> None:
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing import Any
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
|
||||
|
||||
def build_confluence_client(
|
||||
connector_specific_config: dict[str, Any], raw_credentials_json: dict[str, Any]
|
||||
) -> Confluence:
|
||||
is_cloud = connector_specific_config.get("is_cloud", False)
|
||||
return Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=connector_specific_config["wiki_base"].rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=raw_credentials_json["confluence_username"] if is_cloud else None,
|
||||
password=raw_credentials_json["confluence_access_token"] if is_cloud else None,
|
||||
token=raw_credentials_json["confluence_access_token"] if not is_cloud else None,
|
||||
)
|
||||
@@ -1,19 +1,254 @@
|
||||
from typing import Any
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.confluence.confluence_utils import (
|
||||
build_confluence_document_id,
|
||||
)
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
|
||||
build_confluence_client,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_REQUEST_PAGINATION_LIMIT = 100
|
||||
|
||||
|
||||
def _get_space_permissions(
|
||||
db_session: Session,
|
||||
confluence_client: Confluence,
|
||||
space_id: str,
|
||||
) -> ExternalAccess:
|
||||
get_space_permissions = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_space_permissions
|
||||
)
|
||||
|
||||
space_permissions = get_space_permissions(space_id).get("permissions", [])
|
||||
user_emails = set()
|
||||
# Confluence enforces that group names are unique
|
||||
group_names = set()
|
||||
is_externally_public = False
|
||||
for permission in space_permissions:
|
||||
subs = permission.get("subjects")
|
||||
if subs:
|
||||
# If there are subjects, then there are explicit users or groups with access
|
||||
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
|
||||
user_emails.add(email)
|
||||
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
|
||||
group_names.add(group_name)
|
||||
else:
|
||||
# If there are no subjects, then the permission is for everyone
|
||||
if permission.get("operation", {}).get(
|
||||
"operation"
|
||||
) == "read" and permission.get("anonymousAccess", False):
|
||||
# If the permission specifies read access for anonymous users, then
|
||||
# the space is publicly accessible
|
||||
is_externally_public = True
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=list(user_emails)
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_externally_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_restrictions_for_page(
|
||||
db_session: Session,
|
||||
page: dict[str, Any],
|
||||
space_permissions: ExternalAccess,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
WARNING: This function includes no pagination. So if a page is private within
|
||||
the space and has over 200 users or over 200 groups with explicitly read access,
|
||||
this function will leave out some users or groups.
|
||||
200 is a large amount so it is unlikely, but just be aware.
|
||||
"""
|
||||
restrictions_json = page.get("restrictions", {})
|
||||
read_access_dict = restrictions_json.get("read", {}).get("restrictions", {})
|
||||
|
||||
read_access_user_jsons = read_access_dict.get("user", {}).get("results", [])
|
||||
read_access_group_jsons = read_access_dict.get("group", {}).get("results", [])
|
||||
|
||||
is_space_public = read_access_user_jsons == [] and read_access_group_jsons == []
|
||||
|
||||
if not is_space_public:
|
||||
read_access_user_emails = [
|
||||
user["email"] for user in read_access_user_jsons if user.get("email")
|
||||
]
|
||||
read_access_groups = [group["name"] for group in read_access_group_jsons]
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=list(read_access_user_emails)
|
||||
)
|
||||
external_access = ExternalAccess(
|
||||
external_user_emails=set(read_access_user_emails),
|
||||
external_user_group_ids=set(read_access_groups),
|
||||
is_public=False,
|
||||
)
|
||||
else:
|
||||
external_access = space_permissions
|
||||
|
||||
return external_access
|
||||
|
||||
|
||||
def _fetch_attachment_document_ids_for_page_paginated(
|
||||
confluence_client: Confluence, page: dict[str, Any]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Starts by just extracting the first page of attachments from
|
||||
the page. If all attachments are in the first page, then
|
||||
no calls to the api are made from this function.
|
||||
"""
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
|
||||
attachment_doc_ids = []
|
||||
attachments_dict = page["children"]["attachment"]
|
||||
start = 0
|
||||
|
||||
while True:
|
||||
attachments_list = attachments_dict["results"]
|
||||
attachment_doc_ids.extend(
|
||||
[
|
||||
build_confluence_document_id(
|
||||
base_url=confluence_client.url,
|
||||
content_url=attachment["_links"]["download"],
|
||||
)
|
||||
for attachment in attachments_list
|
||||
]
|
||||
)
|
||||
|
||||
if "next" not in attachments_dict["_links"]:
|
||||
break
|
||||
|
||||
start += len(attachments_list)
|
||||
attachments_dict = get_attachments_from_content(
|
||||
page_id=page["id"],
|
||||
start=start,
|
||||
limit=_REQUEST_PAGINATION_LIMIT,
|
||||
)
|
||||
|
||||
return attachment_doc_ids
|
||||
|
||||
|
||||
def _fetch_all_pages_paginated(
|
||||
confluence_client: Confluence,
|
||||
space_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_all_pages_from_space
|
||||
)
|
||||
|
||||
# For each page, this fetches the page's attachments and restrictions.
|
||||
expansion_strings = [
|
||||
"children.attachment",
|
||||
"restrictions.read.restrictions.user",
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
expansion_string = ",".join(expansion_strings)
|
||||
|
||||
all_pages = []
|
||||
start = 0
|
||||
while True:
|
||||
pages_dict = get_all_pages_from_space(
|
||||
space=space_id,
|
||||
start=start,
|
||||
limit=_REQUEST_PAGINATION_LIMIT,
|
||||
expand=expansion_string,
|
||||
)
|
||||
all_pages.extend(pages_dict)
|
||||
|
||||
response_size = len(pages_dict)
|
||||
if response_size < _REQUEST_PAGINATION_LIMIT:
|
||||
break
|
||||
start += response_size
|
||||
|
||||
return all_pages
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions_for_space(
|
||||
db_session: Session,
|
||||
confluence_client: Confluence,
|
||||
space_id: str,
|
||||
space_permissions: ExternalAccess,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
all_pages = _fetch_all_pages_paginated(
|
||||
confluence_client=confluence_client,
|
||||
space_id=space_id,
|
||||
)
|
||||
|
||||
document_restrictions: dict[str, ExternalAccess] = {}
|
||||
for page in all_pages:
|
||||
"""
|
||||
This assigns the same permissions to all attachments of a page and
|
||||
the page itself.
|
||||
This is because the attachments are stored in the same Confluence space as the page.
|
||||
WARNING: We create a dbDocument entry for all attachments, even though attachments
|
||||
may not be their own standalone documents. This is likely fine as we just upsert a
|
||||
document with just permissions.
|
||||
"""
|
||||
attachment_document_ids = [
|
||||
build_confluence_document_id(
|
||||
base_url=confluence_client.url,
|
||||
content_url=page["_links"]["webui"],
|
||||
)
|
||||
]
|
||||
attachment_document_ids.extend(
|
||||
_fetch_attachment_document_ids_for_page_paginated(
|
||||
confluence_client=confluence_client, page=page
|
||||
)
|
||||
)
|
||||
page_permissions = _get_restrictions_for_page(
|
||||
db_session=db_session,
|
||||
page=page,
|
||||
space_permissions=space_permissions,
|
||||
)
|
||||
for attachment_document_id in attachment_document_ids:
|
||||
document_restrictions[attachment_document_id] = page_permissions
|
||||
|
||||
return document_restrictions
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
docs_with_additional_info: list[DocsWithAdditionalInfo],
|
||||
sync_details: dict[str, Any],
|
||||
) -> None:
|
||||
logger.debug("Not yet implemented ACL sync for confluence, no-op")
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated
|
||||
"""
|
||||
confluence_client = build_confluence_client(
|
||||
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
|
||||
)
|
||||
space_permissions = _get_space_permissions(
|
||||
db_session=db_session,
|
||||
confluence_client=confluence_client,
|
||||
space_id=cc_pair.connector.connector_specific_config["space"],
|
||||
)
|
||||
fresh_doc_permissions = _fetch_all_page_restrictions_for_space(
|
||||
db_session=db_session,
|
||||
confluence_client=confluence_client,
|
||||
space_id=cc_pair.connector.connector_specific_config["space"],
|
||||
space_permissions=space_permissions,
|
||||
)
|
||||
for doc_id, ext_access in fresh_doc_permissions.items():
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -1,19 +1,107 @@
|
||||
from typing import Any
|
||||
from collections.abc import Iterator
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.confluence.rate_limit_handler import (
|
||||
make_confluence_call_handle_rate_limit,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
|
||||
build_confluence_client,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_PAGE_SIZE = 100
|
||||
|
||||
|
||||
def _get_confluence_group_names_paginated(
|
||||
confluence_client: Confluence,
|
||||
) -> Iterator[str]:
|
||||
get_all_groups = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_all_groups
|
||||
)
|
||||
|
||||
start = 0
|
||||
while True:
|
||||
try:
|
||||
groups = get_all_groups(start=start, limit=_PAGE_SIZE)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code in (403, 404):
|
||||
return
|
||||
raise e
|
||||
|
||||
for group in groups:
|
||||
if group_name := group.get("name"):
|
||||
yield group_name
|
||||
|
||||
if len(groups) < _PAGE_SIZE:
|
||||
break
|
||||
start += _PAGE_SIZE
|
||||
|
||||
|
||||
def _get_group_members_email_paginated(
|
||||
confluence_client: Confluence,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
get_group_members = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_group_members
|
||||
)
|
||||
group_member_emails: list[str] = []
|
||||
start = 0
|
||||
while True:
|
||||
try:
|
||||
members = get_group_members(
|
||||
group_name=group_name, start=start, limit=_PAGE_SIZE
|
||||
)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403 or e.response.status_code == 404:
|
||||
return group_member_emails
|
||||
raise e
|
||||
|
||||
group_member_emails.extend(
|
||||
[member.get("email") for member in members if member.get("email")]
|
||||
)
|
||||
if len(members) < _PAGE_SIZE:
|
||||
break
|
||||
start += _PAGE_SIZE
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
docs_with_additional_info: list[DocsWithAdditionalInfo],
|
||||
sync_details: dict[str, Any],
|
||||
) -> None:
|
||||
logger.debug("Not yet implemented group sync for confluence, no-op")
|
||||
confluence_client = build_confluence_client(
|
||||
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
|
||||
)
|
||||
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
# Confluence enforces that group names are unique
|
||||
for group_name in _get_confluence_group_names_paginated(confluence_client):
|
||||
group_member_emails = _get_group_members_email_paginated(
|
||||
confluence_client, group_name
|
||||
)
|
||||
group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=group_member_emails
|
||||
)
|
||||
if group_members:
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_name, user_ids=[user.id for user in group_members]
|
||||
)
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=danswer_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -8,15 +10,17 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
@@ -27,6 +31,42 @@ add_retries = retry_builder(tries=5, delay=5, max_delay=30)
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_docs_with_additional_info(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, Any]:
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
start_time = (
|
||||
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
|
||||
if cc_pair.last_time_perm_sync
|
||||
else 0.0
|
||||
)
|
||||
cc_pair.last_time_perm_sync = current_time
|
||||
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time, end=current_time.timestamp()
|
||||
)
|
||||
|
||||
docs_with_additional_info = {
|
||||
doc.id: doc.additional_info
|
||||
for doc_batch in doc_batch_generator
|
||||
for doc in doc_batch
|
||||
}
|
||||
|
||||
return docs_with_additional_info
|
||||
|
||||
|
||||
def _fetch_permissions_paginated(
|
||||
drive_service: Any, drive_file_id: str
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
@@ -122,8 +162,6 @@ def _fetch_google_permissions_for_document_id(
|
||||
def gdrive_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
docs_with_additional_info: list[DocsWithAdditionalInfo],
|
||||
sync_details: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -131,10 +169,24 @@ def gdrive_doc_sync(
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated
|
||||
"""
|
||||
for doc in docs_with_additional_info:
|
||||
sync_details = cc_pair.auto_sync_options
|
||||
if sync_details is None:
|
||||
logger.error("Sync details not found for Google Drive")
|
||||
raise ValueError("Sync details not found for Google Drive")
|
||||
|
||||
# Here we run the connector to grab all the ids
|
||||
# this may grab ids before they are indexed but that is fine because
|
||||
# we create a document in postgres to hold the permissions info
|
||||
# until the indexing job has a chance to run
|
||||
docs_with_additional_info = _get_docs_with_additional_info(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
|
||||
for doc_id, doc_additional_info in docs_with_additional_info.items():
|
||||
ext_access = _fetch_google_permissions_for_document_id(
|
||||
db_session=db_session,
|
||||
drive_file_id=doc.additional_info,
|
||||
drive_file_id=doc_additional_info,
|
||||
raw_credentials_json=cc_pair.credential.credential_json,
|
||||
company_google_domains=[
|
||||
cast(dict[str, str], sync_details)["company_domain"]
|
||||
@@ -142,7 +194,7 @@ def gdrive_doc_sync(
|
||||
)
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=doc.id,
|
||||
doc_id=doc_id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -105,9 +104,12 @@ def _fetch_group_members_paginated(
|
||||
def gdrive_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
docs_with_additional_info: list[DocsWithAdditionalInfo],
|
||||
sync_details: dict[str, Any],
|
||||
) -> None:
|
||||
sync_details = cc_pair.auto_sync_options
|
||||
if sync_details is None:
|
||||
logger.error("Sync details not found for Google Drive")
|
||||
raise ValueError("Sync details not found for Google Drive")
|
||||
|
||||
google_drive_creds, _ = get_google_drive_creds(
|
||||
cc_pair.credential.credential_json,
|
||||
scopes=FETCH_GROUPS_SCOPES,
|
||||
|
||||
@@ -5,31 +5,79 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
DOC_PERMISSIONS_FUNC_MAP,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
FULL_FETCH_PERIOD_IN_SECONDS,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
GROUP_PERMISSIONS_FUNC_MAP,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync_utils import (
|
||||
get_docs_with_additional_info,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def run_permission_sync_entrypoint(
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_external_group_permission_sync(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
|
||||
if group_sync_func is None:
|
||||
# Not all sync connectors support group permissions so this is fine
|
||||
return
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
# in postgres without committing
|
||||
logger.debug(f"Syncing groups for {source_type}")
|
||||
if group_sync_func is not None:
|
||||
group_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
)
|
||||
|
||||
# update postgres
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating document index: {e}")
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def run_external_doc_permission_sync(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
# TODO: seperate out group and doc sync
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
@@ -37,90 +85,57 @@ def run_permission_sync_entrypoint(
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No permission sync function found for source type: {source_type}"
|
||||
)
|
||||
|
||||
sync_details = cc_pair.auto_sync_options
|
||||
if sync_details is None:
|
||||
raise ValueError(f"No auto sync options found for source type: {source_type}")
|
||||
|
||||
# If the source type is not polling, we only fetch the permissions every
|
||||
# _FULL_FETCH_PERIOD_IN_SECONDS seconds
|
||||
full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
|
||||
if full_fetch_period is not None:
|
||||
last_sync = cc_pair.last_time_perm_sync
|
||||
if (
|
||||
last_sync
|
||||
and (
|
||||
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
|
||||
).total_seconds()
|
||||
< full_fetch_period
|
||||
):
|
||||
return
|
||||
|
||||
# Here we run the connector to grab all the ids
|
||||
# this may grab ids before they are indexed but that is fine because
|
||||
# we create a document in postgres to hold the permissions info
|
||||
# until the indexing job has a chance to run
|
||||
docs_with_additional_info = get_docs_with_additional_info(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
|
||||
# This function updates:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
# in postgres without committing
|
||||
logger.debug(f"Syncing groups for {source_type}")
|
||||
if group_sync_func is not None:
|
||||
group_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
docs_with_additional_info,
|
||||
sync_details,
|
||||
)
|
||||
|
||||
# This function updates:
|
||||
# - the user_email <-> document mapping
|
||||
# - the external_user_group_id <-> document mapping
|
||||
# in postgres without committing
|
||||
logger.debug(f"Syncing docs for {source_type}")
|
||||
doc_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
docs_with_additional_info,
|
||||
sync_details,
|
||||
)
|
||||
|
||||
# This function fetches the updated access for the documents
|
||||
# and returns a dictionary of document_ids and access
|
||||
# This is the access we want to update vespa with
|
||||
docs_access = get_access_for_documents(
|
||||
document_ids=[doc.id for doc in docs_with_additional_info],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Then we build the update requests to update vespa
|
||||
update_reqs = [
|
||||
UpdateRequest(document_ids=[doc_id], access=doc_access)
|
||||
for doc_id, doc_access in docs_access.items()
|
||||
]
|
||||
|
||||
# Don't bother sync-ing secondary, it will be sync-ed after switch anyway
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> document mapping
|
||||
# - the external_user_group_id <-> document mapping
|
||||
# in postgres without committing
|
||||
logger.debug(f"Syncing docs for {source_type}")
|
||||
doc_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
)
|
||||
|
||||
# Get the document ids for the cc pair
|
||||
document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# This function fetches the updated access for the documents
|
||||
# and returns a dictionary of document_ids and access
|
||||
# This is the access we want to update vespa with
|
||||
docs_access = get_access_for_documents(
|
||||
document_ids=document_ids_for_cc_pair,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Then we build the update requests to update vespa
|
||||
update_reqs = [
|
||||
UpdateRequest(document_ids=[doc_id], access=doc_access)
|
||||
for doc_id, doc_access in docs_access.items()
|
||||
]
|
||||
|
||||
# Don't bother sync-ing secondary, it will be sync-ed after switch anyway
|
||||
document_index = get_current_primary_default_document_index(db_session)
|
||||
|
||||
# update vespa
|
||||
document_index.update(update_reqs)
|
||||
|
||||
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
|
||||
|
||||
# update postgres
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating document index: {e}")
|
||||
logger.error(f"Error Syncing Permissions: {e}")
|
||||
db_session.rollback()
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocsWithAdditionalInfo(BaseModel):
|
||||
id: str
|
||||
additional_info: Any
|
||||
|
||||
|
||||
def get_docs_with_additional_info(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocsWithAdditionalInfo]:
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
start_time = (
|
||||
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
|
||||
if cc_pair.last_time_perm_sync
|
||||
else 0
|
||||
)
|
||||
cc_pair.last_time_perm_sync = current_time
|
||||
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time, end=current_time.timestamp()
|
||||
)
|
||||
|
||||
docs_with_additional_info = [
|
||||
DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
|
||||
for doc_batch in doc_batch_generator
|
||||
for doc in doc_batch
|
||||
]
|
||||
logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
|
||||
|
||||
return docs_with_additional_info
|
||||
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.slack.connector import get_channels
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_channel_id_from_doc_id(doc_id: str) -> str:
|
||||
"""
|
||||
Extracts the channel ID from a document ID string.
|
||||
|
||||
The document ID is expected to be in the format: "{channel_id}__{message_ts}"
|
||||
|
||||
Args:
|
||||
doc_id (str): The document ID string.
|
||||
|
||||
Returns:
|
||||
str: The extracted channel ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If the doc_id doesn't contain the expected separator.
|
||||
"""
|
||||
try:
|
||||
channel_id, _ = doc_id.split("__", 1)
|
||||
return channel_id
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid doc_id format: {doc_id}")
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.PRUNE,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
assert isinstance(runnable_connector, IdConnector)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_id in runnable_connector.retrieve_all_source_ids():
|
||||
channel_id = _extract_channel_id_from_doc_id(doc_id)
|
||||
if channel_id not in channel_doc_map:
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_id)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
def _fetch_worspace_permissions(
|
||||
db_session: Session,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
user_emails = set()
|
||||
for email in user_id_to_email_map.values():
|
||||
user_emails.add(email)
|
||||
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
# No group<->document mapping for slack
|
||||
external_user_group_ids=set(),
|
||||
# No way to determine if slack is invite only without enterprise liscense
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_channel_permissions(
    db_session: Session,
    slack_client: WebClient,
    workspace_permissions: ExternalAccess,
    user_id_to_email_map: dict[str, str],
) -> dict[str, ExternalAccess]:
    """Compute per-channel access for all public and private Slack channels.

    Public channels inherit the workspace-wide permissions; private channels
    get an ExternalAccess restricted to their members' emails.

    Args:
        db_session: Active DB session (caller owns the commit).
        slack_client: Authenticated Slack WebClient.
        workspace_permissions: Access object applied to every public channel.
        user_id_to_email_map: Slack user id -> email cache; MUTATED in place
            when a member's email has to be fetched via users_info.

    Returns:
        Mapping of channel_id -> ExternalAccess for that channel.
    """
    channel_permissions: dict[str, ExternalAccess] = {}
    public_channels = get_channels(
        client=slack_client,
        get_public=True,
        get_private=False,
    )
    public_channel_ids = [
        channel["id"] for channel in public_channels if "id" in channel
    ]
    # Public channels are visible to the whole workspace
    for channel_id in public_channel_ids:
        channel_permissions[channel_id] = workspace_permissions

    private_channels = get_channels(
        client=slack_client,
        get_public=False,
        get_private=True,
    )
    private_channel_ids = [
        channel["id"] for channel in private_channels if "id" in channel
    ]

    for channel_id in private_channel_ids:
        # Collect all member ids for the channel pagination calls
        member_ids: list[str] = []
        for result in make_paginated_slack_api_call_w_retries(
            slack_client.conversations_members,
            channel=channel_id,
        ):
            member_ids.extend(result.get("members", []))

        # Collect all member emails for the channel
        member_emails: set[str] = set()
        for member_id in member_ids:
            member_email = user_id_to_email_map.get(member_id)

            if not member_email:
                # If the user is an external user, they won't be in the cached
                # user_id_to_email_map, so we make a separate users_info call
                # and add them to the user_id_to_email_map
                member_info = slack_client.users_info(user=member_id)
                member_email = member_info["user"]["profile"].get("email")
                if not member_email:
                    # If no email is found, we skip the user (e.g. bots)
                    continue
                user_id_to_email_map[member_id] = member_email
                # Make sure the newly discovered user exists in postgres
                batch_add_non_web_user_if_not_exists__no_commit(
                    db_session, [member_email]
                )

            member_emails.add(member_email)

        channel_permissions[channel_id] = ExternalAccess(
            external_user_emails=member_emails,
            # No group<->document mapping for slack
            external_user_group_ids=set(),
            # No way to determine if slack is invite-only without an enterprise license
            is_public=False,
        )

    return channel_permissions
|
||||
|
||||
|
||||
def slack_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
    """
    Adds the external permissions to the documents in postgres.
    If a document doesn't already exist in postgres, we create it so that
    when it gets indexed later, the permissions are already populated.

    Nothing is committed here; the caller owns the transaction.
    """
    slack_client = WebClient(
        token=cc_pair.credential.credential_json["slack_bot_token"]
    )
    # Cache of Slack user id -> email; also mutated by the helpers below
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
    # channel_id -> list of document ids in that channel
    channel_doc_map = _get_slack_document_ids_and_channels(
        db_session=db_session,
        cc_pair=cc_pair,
    )
    # Workspace-wide access, applied to public channels
    workspace_permissions = _fetch_worspace_permissions(
        db_session=db_session,
        user_id_to_email_map=user_id_to_email_map,
    )
    # channel_id -> ExternalAccess (membership-restricted for private channels)
    channel_permissions = _fetch_channel_permissions(
        db_session=db_session,
        slack_client=slack_client,
        workspace_permissions=workspace_permissions,
        user_id_to_email_map=user_id_to_email_map,
    )
    for channel_id, ext_access in channel_permissions.items():
        doc_ids = channel_doc_map.get(channel_id)
        if not doc_ids:
            # No documents found for this channel_id
            continue

        for doc_id in doc_ids:
            upsert_document_external_perms__no_commit(
                db_session=db_session,
                doc_id=doc_id,
                external_access=ext_access,
                source_type=cc_pair.connector.source,
            )
|
||||
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING
|
||||
WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL
|
||||
SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL
|
||||
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
|
||||
"""
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_group_ids(
    slack_client: WebClient,
) -> list[str]:
    """Return the id of every Slack usergroup in the workspace."""
    # Flatten all pages of usergroups into a single list of ids.
    return [
        group.get("id")
        for result in make_paginated_slack_api_call_w_retries(
            slack_client.usergroups_list
        )
        for group in result.get("usergroups", [])
    ]
|
||||
|
||||
|
||||
def _get_slack_group_members_email(
    db_session: Session,
    slack_client: WebClient,
    group_name: str,
    user_id_to_email_map: dict[str, str],
) -> list[str]:
    """Return the emails of all members of a Slack usergroup.

    Args:
        db_session: Active DB session (caller owns the commit).
        slack_client: Authenticated Slack WebClient.
        group_name: Slack usergroup id to enumerate.
        user_id_to_email_map: Slack user id -> email cache; MUTATED in place
            when a member's email has to be fetched via users_info.

    Returns:
        List of member emails (users without an email are skipped).
    """
    group_member_emails = []
    for result in make_paginated_slack_api_call_w_retries(
        slack_client.usergroups_users_list, usergroup=group_name
    ):
        for member_id in result.get("users", []):
            member_email = user_id_to_email_map.get(member_id)
            if not member_email:
                # If the user is an external user, they won't be in the cached
                # map, so we make a separate call to users_info
                member_info = slack_client.users_info(user=member_id)
                member_email = member_info["user"]["profile"].get("email")
                if not member_email:
                    # If no email is found, we skip the user
                    continue
                user_id_to_email_map[member_id] = member_email
                # Ensure the newly discovered user exists in postgres
                batch_add_non_web_user_if_not_exists__no_commit(
                    db_session, [member_email]
                )
            group_member_emails.append(member_email)

    return group_member_emails
|
||||
|
||||
|
||||
def slack_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
    """Sync Slack usergroup membership into postgres for this cc_pair.

    Builds an ExternalUserGroup per Slack usergroup (keyed by the usergroup
    id) and replaces the cc_pair's existing user<->group mappings. Nothing
    is committed here; the caller owns the transaction.

    NOTE: per the module docstring, this is not used for permission checks —
    Slack resolves usergroups to individual users on channels.
    """
    slack_client = WebClient(
        token=cc_pair.credential.credential_json["slack_bot_token"]
    )
    # Cache of Slack user id -> email; mutated by the member lookup below
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)

    danswer_groups: list[ExternalUserGroup] = []
    # group_name is the Slack usergroup id returned by usergroups.list
    for group_name in _get_slack_group_ids(slack_client):
        group_member_emails = _get_slack_group_members_email(
            db_session=db_session,
            slack_client=slack_client,
            group_name=group_name,
            user_id_to_email_map=user_id_to_email_map,
        )
        # Create any missing users and get their DB rows (no commit)
        group_members = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=group_member_emails
        )
        if group_members:
            danswer_groups.append(
                ExternalUserGroup(
                    id=group_name, user_ids=[user.id for user in group_members]
                )
            )

    # Replace (not merge) the cc_pair's user<->group mappings
    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
|
||||
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
|
||||
|
||||
def fetch_user_id_to_email_map(
    slack_client: WebClient,
) -> dict[str, str]:
    """Return a mapping of Slack user id -> email for all workspace users.

    Users whose profile has no email (e.g. bots, deactivated accounts) are
    skipped. Paginates through users.list with retries.

    Args:
        slack_client: Authenticated Slack WebClient.

    Returns:
        dict mapping Slack user id to the user's profile email.
    """
    user_id_to_email_map: dict[str, str] = {}
    for user_info in make_paginated_slack_api_call_w_retries(
        slack_client.users_list,
    ):
        for user in user_info.get("members", []):
            # Look the email up once instead of twice (original did a
            # redundant second profile/email lookup for the value).
            email = user.get("profile", {}).get("email")
            if email:
                user_id_to_email_map[user.get("id")] = email
    return user_id_to_email_map
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -9,15 +8,14 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_s
|
||||
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
|
||||
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
|
||||
None,
|
||||
]
|
||||
|
||||
DocSyncFuncType = Callable[
|
||||
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
|
||||
# Defining the input/output types for the sync functions
|
||||
SyncFuncType = Callable[
|
||||
[
|
||||
Session,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
None,
|
||||
]
|
||||
|
||||
@@ -26,27 +24,27 @@ DocSyncFuncType = Callable[
|
||||
# - the external_user_group_id <-> document mapping
|
||||
# in postgres without committing
|
||||
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
|
||||
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
|
||||
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
|
||||
DocumentSource.CONFLUENCE: confluence_doc_sync,
|
||||
DocumentSource.SLACK: slack_doc_sync,
|
||||
}
|
||||
|
||||
# These functions update:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
# in postgres without committing
|
||||
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
|
||||
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
|
||||
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
|
||||
DocumentSource.CONFLUENCE: confluence_group_sync,
|
||||
}
|
||||
|
||||
|
||||
# None means that the connector supports polling from last_time_perm_sync to now
|
||||
FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
|
||||
# Polling is supported
|
||||
DocumentSource.GOOGLE_DRIVE: None,
|
||||
# Polling is not supported so we fetch all doc permissions every 10 minutes
|
||||
DocumentSource.CONFLUENCE: 10 * 60,
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -6,8 +7,20 @@ from pydantic import Field
|
||||
|
||||
class NavigationItem(BaseModel):
|
||||
link: str
|
||||
icon: str
|
||||
title: str
|
||||
# Right now must be one of the FA icons
|
||||
icon: str | None = None
|
||||
# NOTE: SVG must not have a width / height specified
|
||||
# This is the actual SVG as a string. Done this way to reduce
|
||||
# complexity / having to store additional "logos" in Postgres
|
||||
svg_logo: str | None = None
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, *args: Any, **kwargs: Any) -> "NavigationItem":
|
||||
instance = super().model_validate(*args, **kwargs)
|
||||
if bool(instance.icon) == bool(instance.svg_logo):
|
||||
raise ValueError("Exactly one of fa_icon or svg_logo must be specified")
|
||||
return instance
|
||||
|
||||
|
||||
class EnterpriseSettings(BaseModel):
|
||||
|
||||
@@ -12,7 +12,6 @@ from fastapi_users import exceptions
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserCreate
|
||||
@@ -61,7 +60,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
|
||||
user: User = await user_manager.create(
|
||||
UserCreate(
|
||||
email=EmailStr(email),
|
||||
email=email,
|
||||
password=hashed_pass,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
@@ -22,6 +23,7 @@ from ee.danswer.db.standard_answer import (
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.danswer.server.enterprise_settings.models import NavigationItem
|
||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||
from ee.danswer.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
@@ -44,6 +46,13 @@ logger = setup_logger()
|
||||
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
||||
|
||||
|
||||
class NavigationItemSeed(BaseModel):
|
||||
link: str
|
||||
title: str
|
||||
# NOTE: SVG at this path must not have a width / height specified
|
||||
svg_path: str
|
||||
|
||||
|
||||
class SeedConfiguration(BaseModel):
|
||||
llms: list[LLMProviderUpsertRequest] | None = None
|
||||
admin_user_emails: list[str] | None = None
|
||||
@@ -51,6 +60,10 @@ class SeedConfiguration(BaseModel):
|
||||
personas: list[CreatePersonaRequest] | None = None
|
||||
settings: Settings | None = None
|
||||
enterprise_settings: EnterpriseSettings | None = None
|
||||
|
||||
# allows for specifying custom navigation items that have your own custom SVG logos
|
||||
nav_item_overrides: list[NavigationItemSeed] | None = None
|
||||
|
||||
# Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference
|
||||
analytics_script_path: str | None = None
|
||||
custom_tools: List[CustomToolSeed] | None = None
|
||||
@@ -60,7 +73,7 @@ def _parse_env() -> SeedConfiguration | None:
|
||||
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
|
||||
if not seed_config_str:
|
||||
return None
|
||||
seed_config = SeedConfiguration.parse_raw(seed_config_str)
|
||||
seed_config = SeedConfiguration.model_validate_json(seed_config_str)
|
||||
return seed_config
|
||||
|
||||
|
||||
@@ -152,9 +165,35 @@ def _seed_settings(settings: Settings) -> None:
|
||||
|
||||
|
||||
def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
|
||||
if seed_config.enterprise_settings is not None:
|
||||
if (
|
||||
seed_config.enterprise_settings is not None
|
||||
or seed_config.nav_item_overrides is not None
|
||||
):
|
||||
final_enterprise_settings = (
|
||||
deepcopy(seed_config.enterprise_settings)
|
||||
if seed_config.enterprise_settings
|
||||
else EnterpriseSettings()
|
||||
)
|
||||
|
||||
final_nav_items = final_enterprise_settings.custom_nav_items
|
||||
if seed_config.nav_item_overrides is not None:
|
||||
final_nav_items = []
|
||||
for item in seed_config.nav_item_overrides:
|
||||
with open(item.svg_path, "r") as file:
|
||||
svg_content = file.read().strip()
|
||||
|
||||
final_nav_items.append(
|
||||
NavigationItem(
|
||||
link=item.link,
|
||||
title=item.title,
|
||||
svg_logo=svg_content,
|
||||
)
|
||||
)
|
||||
|
||||
final_enterprise_settings.custom_nav_items = final_nav_items
|
||||
|
||||
logger.notice("Seeding enterprise settings")
|
||||
store_ee_settings(seed_config.enterprise_settings)
|
||||
store_ee_settings(final_enterprise_settings)
|
||||
|
||||
|
||||
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
[pytest]
|
||||
pythonpath = .
|
||||
markers =
|
||||
slow: marks tests as slow
|
||||
slow: marks tests as slow
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::cryptography.utils.CryptographyDeprecationWarning
|
||||
|
||||
@@ -4,7 +4,7 @@ asyncpg==0.27.0
|
||||
atlassian-python-api==3.37.0
|
||||
beautifulsoup4==4.12.2
|
||||
boto3==1.34.84
|
||||
celery==5.3.4
|
||||
celery==5.5.0b4
|
||||
chardet==5.2.0
|
||||
dask==2023.8.1
|
||||
ddtrace==2.6.5
|
||||
|
||||
@@ -18,7 +18,8 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
|
||||
|
||||
|
||||
def run_jobs(exclude_indexing: bool) -> None:
|
||||
cmd_worker = [
|
||||
# command setup
|
||||
cmd_worker_primary = [
|
||||
"celery",
|
||||
"-A",
|
||||
"ee.danswer.background.celery.celery_app",
|
||||
@@ -26,8 +27,38 @@ def run_jobs(exclude_indexing: bool) -> None:
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--loglevel=INFO",
|
||||
"-n",
|
||||
"primary@%n",
|
||||
"-Q",
|
||||
"celery,vespa_metadata_sync,connector_deletion",
|
||||
"celery",
|
||||
]
|
||||
|
||||
cmd_worker_light = [
|
||||
"celery",
|
||||
"-A",
|
||||
"ee.danswer.background.celery.celery_app",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=16",
|
||||
"--loglevel=INFO",
|
||||
"-n",
|
||||
"light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"ee.danswer.background.celery.celery_app",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--loglevel=INFO",
|
||||
"-n",
|
||||
"heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
@@ -38,19 +69,38 @@ def run_jobs(exclude_indexing: bool) -> None:
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
|
||||
worker_process = subprocess.Popen(
|
||||
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
# spawn processes
|
||||
worker_primary_process = subprocess.Popen(
|
||||
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_light_process = subprocess.Popen(
|
||||
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_heavy_process = subprocess.Popen(
|
||||
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=monitor_process, args=("WORKER", worker_process)
|
||||
# monitor threads
|
||||
worker_primary_thread = threading.Thread(
|
||||
target=monitor_process, args=("PRIMARY", worker_primary_process)
|
||||
)
|
||||
worker_light_thread = threading.Thread(
|
||||
target=monitor_process, args=("LIGHT", worker_light_process)
|
||||
)
|
||||
worker_heavy_thread = threading.Thread(
|
||||
target=monitor_process, args=("HEAVY", worker_heavy_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
worker_thread.start()
|
||||
worker_primary_thread.start()
|
||||
worker_light_thread.start()
|
||||
worker_heavy_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
if not exclude_indexing:
|
||||
@@ -93,7 +143,9 @@ def run_jobs(exclude_indexing: bool) -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
worker_thread.join()
|
||||
worker_primary_thread.join()
|
||||
worker_light_thread.join()
|
||||
worker_heavy_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
|
||||
|
||||
@@ -24,23 +24,59 @@ autorestart=true
|
||||
# on a system, but this should be okay for now since all our celery tasks are
|
||||
# relatively compute-light (e.g. they tend to just make a bunch of requests to
|
||||
# Vespa / Postgres)
|
||||
[program:celery_worker]
|
||||
[program:celery_worker_primary]
|
||||
command=celery -A danswer.background.celery.celery_run:celery_app worker
|
||||
--pool=threads
|
||||
--concurrency=6
|
||||
--concurrency=4
|
||||
--prefetch-multiplier=1
|
||||
--loglevel=INFO
|
||||
--logfile=/var/log/celery_worker_supervisor.log
|
||||
-Q celery,vespa_metadata_sync,connector_deletion
|
||||
environment=LOG_FILE_NAME=celery_worker
|
||||
--hostname=primary@%%n
|
||||
-Q celery
|
||||
stdout_logfile=/var/log/celery_worker_primary.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
[program:celery_worker_light]
|
||||
command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \
|
||||
--pool=threads \
|
||||
--concurrency=${CELERY_WORKER_LIGHT_CONCURRENCY:-24} \
|
||||
--prefetch-multiplier=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-8} \
|
||||
--loglevel=INFO \
|
||||
--hostname=light@%%n \
|
||||
-Q vespa_metadata_sync,connector_deletion"
|
||||
stdout_logfile=/var/log/celery_worker_light.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
[program:celery_worker_heavy]
|
||||
command=celery -A danswer.background.celery.celery_run:celery_app worker
|
||||
--pool=threads
|
||||
--concurrency=4
|
||||
--prefetch-multiplier=1
|
||||
--loglevel=INFO
|
||||
--hostname=heavy@%%n
|
||||
-Q connector_pruning
|
||||
stdout_logfile=/var/log/celery_worker_heavy.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
# Job scheduler for periodic tasks
|
||||
[program:celery_beat]
|
||||
command=celery -A danswer.background.celery.celery_run:celery_app beat
|
||||
--logfile=/var/log/celery_beat_supervisor.log
|
||||
environment=LOG_FILE_NAME=celery_beat
|
||||
command=celery -A danswer.background.celery.celery_run:celery_app beat
|
||||
stdout_logfile=/var/log/celery_beat.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
# Listens for Slack messages and responds with answers
|
||||
# for all channels that the DanswerBot has been added to.
|
||||
@@ -58,13 +94,12 @@ startsecs=60
|
||||
# No log rotation here, since it's stdout it's handled by the Docker container logging
|
||||
[program:log-redirect-handler]
|
||||
command=tail -qF
|
||||
/var/log/celery_beat.log
|
||||
/var/log/celery_worker_primary.log
|
||||
/var/log/celery_worker_light.log
|
||||
/var/log/celery_worker_heavy.log
|
||||
/var/log/document_indexing_info.log
|
||||
/var/log/celery_beat_supervisor.log
|
||||
/var/log/celery_worker_supervisor.log
|
||||
/var/log/celery_beat_debug.log
|
||||
/var/log/celery_worker_debug.log
|
||||
/var/log/slack_bot_debug.log
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout
|
||||
autorestart=true
|
||||
|
||||
48
backend/tests/daily/connectors/jira/test_jira_basic.py
Normal file
48
backend/tests/daily/connectors/jira/test_jira_basic.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.danswer_jira.connector import JiraConnector
|
||||
|
||||
|
||||
@pytest.fixture
def jira_connector() -> JiraConnector:
    """Return a JiraConnector wired to the danswerai test board.

    Requires JIRA_USER_EMAIL and JIRA_API_TOKEN in the environment;
    raises KeyError if either is missing.
    """
    connector = JiraConnector(
        # Fixed test board in the danswerai Jira instance
        "https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
        comment_email_blacklist=[],
    )
    connector.load_credentials(
        {
            "jira_user_email": os.environ["JIRA_USER_EMAIL"],
            "jira_api_token": os.environ["JIRA_API_TOKEN"],
        }
    )
    return connector
|
||||
|
||||
|
||||
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
    """Poll the test Jira board and verify the single expected document.

    Asserts the connector yields exactly one batch containing exactly one
    document, and checks that document's identity, metadata, and section
    content against the known AS-2 issue.
    """
    # Poll from the epoch to now so the one test issue is always included
    doc_batch_generator = jira_connector.poll_source(0, time.time())

    # Exactly one batch is expected
    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 1

    doc = doc_batch[0]

    assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
    assert doc.semantic_identifier == "test123small"
    assert doc.source == DocumentSource.JIRA
    assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
    assert doc.secondary_owners is None
    assert doc.title is None
    assert doc.from_ingestion_api is False
    assert doc.additional_info is None

    # The issue body becomes a single section linking back to the issue
    assert len(doc.sections) == 1
    section = doc.sections[0]
    assert section.text == "example_text\n"
    assert section.link == "https://danswerai.atlassian.net/browse/AS-2"
|
||||
@@ -72,6 +72,7 @@ COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY ./pytest.ini /app/pytest.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Integration test stuff
|
||||
|
||||
@@ -6,8 +6,8 @@ from danswer.db.models import UserRole
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
@@ -15,8 +15,8 @@ class APIKeyManager:
|
||||
def create(
|
||||
name: str | None = None,
|
||||
api_key_role: UserRole = UserRole.ADMIN,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestAPIKey:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestAPIKey:
|
||||
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
|
||||
api_key_request = APIKeyArgs(
|
||||
name=name,
|
||||
@@ -31,7 +31,7 @@ class APIKeyManager:
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
api_key = api_key_response.json()
|
||||
result_api_key = TestAPIKey(
|
||||
result_api_key = DATestAPIKey(
|
||||
api_key_id=api_key["api_key_id"],
|
||||
api_key_display=api_key["api_key_display"],
|
||||
api_key=api_key["api_key"],
|
||||
@@ -45,8 +45,8 @@ class APIKeyManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
api_key: TestAPIKey,
|
||||
user_performing_action: TestUser | None = None,
|
||||
api_key: DATestAPIKey,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
api_key_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
|
||||
@@ -58,8 +58,8 @@ class APIKeyManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestAPIKey]:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DATestAPIKey]:
|
||||
api_key_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
headers=user_performing_action.headers
|
||||
@@ -67,13 +67,13 @@ class APIKeyManager:
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
return [TestAPIKey(**api_key) for api_key in api_key_response.json()]
|
||||
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
api_key: TestAPIKey,
|
||||
api_key: DATestAPIKey,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_keys = APIKeyManager.get_all(
|
||||
user_performing_action=user_performing_action
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -7,6 +8,8 @@ import requests
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.server.documents.models import CCPairPruningTask
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
@@ -15,8 +18,8 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.test_models import TestCCPair
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _cc_pair_creator(
|
||||
@@ -25,8 +28,8 @@ def _cc_pair_creator(
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
|
||||
|
||||
request = {
|
||||
@@ -43,7 +46,7 @@ def _cc_pair_creator(
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return TestCCPair(
|
||||
return DATestCCPair(
|
||||
id=response.json()["data"],
|
||||
name=name,
|
||||
connector_id=connector_id,
|
||||
@@ -63,8 +66,8 @@ class CCPairManager:
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
connector = ConnectorManager.create(
|
||||
name=name,
|
||||
source=source,
|
||||
@@ -98,8 +101,8 @@ class CCPairManager:
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
return _cc_pair_creator(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
@@ -111,8 +114,8 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def pause_cc_pair(
|
||||
cc_pair: TestCCPair,
|
||||
user_performing_action: TestUser | None = None,
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
@@ -125,8 +128,8 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
cc_pair: TestCCPair,
|
||||
user_performing_action: TestUser | None = None,
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -141,9 +144,28 @@ class CCPairManager:
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_one(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> ConnectorIndexingStatus | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for cc_pair_json in response.json():
|
||||
cc_pair = ConnectorIndexingStatus(**cc_pair_json)
|
||||
if cc_pair.cc_pair_id == cc_pair_id:
|
||||
return cc_pair
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
@@ -156,9 +178,9 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
cc_pair: TestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for retrieved_cc_pair in all_cc_pairs:
|
||||
@@ -182,10 +204,99 @@ class CCPairManager:
|
||||
raise ValueError(f"CC pair {cc_pair.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: TestUser | None = None,
|
||||
def wait_for_indexing(
|
||||
cc_pair_test: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
"""after: Wait for an indexing success time after this time"""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for cc_pair in cc_pairs:
|
||||
if cc_pair.cc_pair_id != cc_pair_test.id:
|
||||
continue
|
||||
|
||||
if cc_pair.last_success and cc_pair.last_success > after:
|
||||
print(f"cc_pair {cc_pair_test.id} indexing complete.")
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"CC pair indexing was not completed within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for CC indexing to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
@staticmethod
|
||||
def prune(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_prune_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> CCPairPruningTask:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CCPairPruningTask(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def wait_for_prune(
|
||||
cc_pair_test: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action)
|
||||
if not task:
|
||||
raise ValueError("Prune task not found.")
|
||||
|
||||
if not task.register_time or task.register_time < after:
|
||||
raise ValueError("Prune task register time is too early.")
|
||||
|
||||
if task.status == TaskStatus.SUCCESS:
|
||||
# Pruning succeeded
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"CC pair pruning was not completed within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for CC pruning to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
if all(
|
||||
@@ -194,7 +305,7 @@ class CCPairManager:
|
||||
):
|
||||
return
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
if time.monotonic() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"CC pairs deletion was not completed within the {MAX_DELAY} seconds"
|
||||
)
|
||||
|
||||
@@ -13,10 +13,10 @@ from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestChatMessage
|
||||
from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import StreamedResponse
|
||||
from tests.integration.common_utils.test_models import TestChatMessage
|
||||
from tests.integration.common_utils.test_models import TestChatSession
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class ChatSessionManager:
|
||||
@@ -24,8 +24,8 @@ class ChatSessionManager:
|
||||
def create(
|
||||
persona_id: int = -1,
|
||||
description: str = "Test chat session",
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestChatSession:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
)
|
||||
@@ -38,7 +38,7 @@ class ChatSessionManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
return TestChatSession(
|
||||
return DATestChatSession(
|
||||
id=chat_session_id, persona_id=persona_id, description=description
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ class ChatSessionManager:
|
||||
chat_session_id: int,
|
||||
message: str,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] = [],
|
||||
prompt_id: int | None = None,
|
||||
search_doc_ids: list[int] | None = None,
|
||||
@@ -90,7 +90,7 @@ class ChatSessionManager:
|
||||
def get_answer_with_quote(
|
||||
persona_id: int,
|
||||
message: str,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> StreamedResponse:
|
||||
direct_qa_request = DirectQARequest(
|
||||
messages=[ThreadMessage(message=message)],
|
||||
@@ -137,9 +137,9 @@ class ChatSessionManager:
|
||||
|
||||
@staticmethod
|
||||
def get_chat_history(
|
||||
chat_session: TestChatSession,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestChatMessage]:
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DATestChatMessage]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
|
||||
headers=user_performing_action.headers
|
||||
@@ -149,7 +149,7 @@ class ChatSessionManager:
|
||||
response.raise_for_status()
|
||||
|
||||
return [
|
||||
TestChatMessage(
|
||||
DATestChatMessage(
|
||||
id=msg["id"],
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=msg.get("parent_message_id"),
|
||||
|
||||
@@ -8,8 +8,8 @@ from danswer.server.documents.models import ConnectorUpdateRequest
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestConnector
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class ConnectorManager:
|
||||
@@ -21,8 +21,8 @@ class ConnectorManager:
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestConnector:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestConnector:
|
||||
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
|
||||
|
||||
connector_update_request = ConnectorUpdateRequest(
|
||||
@@ -44,7 +44,7 @@ class ConnectorManager:
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
return TestConnector(
|
||||
return DATestConnector(
|
||||
id=response_data.get("id"),
|
||||
name=name,
|
||||
source=source,
|
||||
@@ -56,8 +56,8 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
connector: TestConnector,
|
||||
user_performing_action: TestUser | None = None,
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
@@ -70,8 +70,8 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
connector: TestConnector,
|
||||
user_performing_action: TestUser | None = None,
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
@@ -83,8 +83,8 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestConnector]:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DATestConnector]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector",
|
||||
headers=user_performing_action.headers
|
||||
@@ -93,7 +93,7 @@ class ConnectorManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
TestConnector(
|
||||
DATestConnector(
|
||||
id=conn.get("id"),
|
||||
name=conn.get("name", ""),
|
||||
source=conn.get("source", DocumentSource.FILE),
|
||||
@@ -105,8 +105,8 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
connector_id: int, user_performing_action: TestUser | None = None
|
||||
) -> TestConnector:
|
||||
connector_id: int, user_performing_action: DATestUser | None = None
|
||||
) -> DATestConnector:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
|
||||
headers=user_performing_action.headers
|
||||
@@ -115,7 +115,7 @@ class ConnectorManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
conn = response.json()
|
||||
return TestConnector(
|
||||
return DATestConnector(
|
||||
id=conn.get("id"),
|
||||
name=conn.get("name", ""),
|
||||
source=conn.get("source", DocumentSource.FILE),
|
||||
|
||||
@@ -7,8 +7,8 @@ from danswer.server.documents.models import CredentialSnapshot
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestCredential
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class CredentialManager:
|
||||
@@ -20,8 +20,8 @@ class CredentialManager:
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
curator_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCredential:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCredential:
|
||||
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
|
||||
|
||||
credential_request = {
|
||||
@@ -41,7 +41,7 @@ class CredentialManager:
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return TestCredential(
|
||||
return DATestCredential(
|
||||
id=response.json()["id"],
|
||||
name=name,
|
||||
credential_json=credential_json or {},
|
||||
@@ -53,8 +53,8 @@ class CredentialManager:
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
credential: TestCredential,
|
||||
user_performing_action: TestUser | None = None,
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
request = credential.model_dump(include={"name", "credential_json"})
|
||||
response = requests.put(
|
||||
@@ -68,8 +68,8 @@ class CredentialManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
credential: TestCredential,
|
||||
user_performing_action: TestUser | None = None,
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
|
||||
@@ -81,7 +81,7 @@ class CredentialManager:
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
credential_id: int, user_performing_action: TestUser | None = None
|
||||
credential_id: int, user_performing_action: DATestUser | None = None
|
||||
) -> CredentialSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
|
||||
@@ -94,7 +94,7 @@ class CredentialManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[CredentialSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/credential",
|
||||
@@ -107,9 +107,9 @@ class CredentialManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
credential: TestCredential,
|
||||
credential: DATestCredential,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_credentials = CredentialManager.get_all(user_performing_action)
|
||||
for fetched_credential in all_credentials:
|
||||
|
||||
@@ -7,19 +7,19 @@ from danswer.db.enums import AccessType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import TestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import TestCCPair
|
||||
from tests.integration.common_utils.managers.api_key import DATestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import SimpleTestDocument
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def _verify_document_permissions(
|
||||
retrieved_doc: dict,
|
||||
cc_pair: TestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: TestUser | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
) -> None:
|
||||
acl_keys = set(retrieved_doc["access_control_list"].keys())
|
||||
print(f"ACL keys: {acl_keys}")
|
||||
@@ -83,10 +83,10 @@ def _generate_dummy_document(
|
||||
class DocumentManager:
|
||||
@staticmethod
|
||||
def seed_dummy_docs(
|
||||
cc_pair: TestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
num_docs: int = NUM_DOCS,
|
||||
document_ids: list[str] | None = None,
|
||||
api_key: TestAPIKey | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
) -> list[SimpleTestDocument]:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
if document_ids is None:
|
||||
@@ -116,10 +116,10 @@ class DocumentManager:
|
||||
|
||||
@staticmethod
|
||||
def seed_doc_with_content(
|
||||
cc_pair: TestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
content: str,
|
||||
document_id: str | None = None,
|
||||
api_key: TestAPIKey | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
) -> SimpleTestDocument:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
if document_id is None:
|
||||
@@ -142,13 +142,13 @@ class DocumentManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
vespa_client: TestVespaClient,
|
||||
cc_pair: TestCCPair,
|
||||
vespa_client: vespa_fixture,
|
||||
cc_pair: DATestCCPair,
|
||||
# If None, will not check doc sets or groups
|
||||
# If empty list, will check for empty doc sets or groups
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: TestUser | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
doc_ids = [document.id for document in cc_pair.documents]
|
||||
|
||||
@@ -6,8 +6,8 @@ import requests
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import TestDocumentSet
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestDocumentSet
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class DocumentSetManager:
|
||||
@@ -19,8 +19,8 @@ class DocumentSetManager:
|
||||
is_public: bool = True,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestDocumentSet:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestDocumentSet:
|
||||
if name is None:
|
||||
name = f"test_doc_set_{str(uuid4())}"
|
||||
|
||||
@@ -42,7 +42,7 @@ class DocumentSetManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return TestDocumentSet(
|
||||
return DATestDocumentSet(
|
||||
id=int(response.json()),
|
||||
name=name,
|
||||
description=description or name,
|
||||
@@ -55,8 +55,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
document_set: TestDocumentSet,
|
||||
user_performing_action: TestUser | None = None,
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
@@ -78,8 +78,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_set: TestDocumentSet,
|
||||
user_performing_action: TestUser | None = None,
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
|
||||
@@ -92,8 +92,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestDocumentSet]:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DATestDocumentSet]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/document-set",
|
||||
headers=user_performing_action.headers
|
||||
@@ -102,7 +102,7 @@ class DocumentSetManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
TestDocumentSet(
|
||||
DATestDocumentSet(
|
||||
id=doc_set["id"],
|
||||
name=doc_set["name"],
|
||||
description=doc_set["description"],
|
||||
@@ -119,8 +119,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
document_sets_to_check: list[TestDocumentSet] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
document_sets_to_check: list[DATestDocumentSet] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
# wait for document sets to be synced
|
||||
start = time.time()
|
||||
@@ -148,9 +148,9 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
document_set: TestDocumentSet,
|
||||
document_set: DATestDocumentSet,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
doc_sets = DocumentSetManager.get_all(user_performing_action)
|
||||
for doc_set in doc_sets:
|
||||
|
||||
@@ -3,11 +3,12 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestLLMProvider
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class LLMProviderManager:
|
||||
@@ -21,8 +22,8 @@ class LLMProviderManager:
|
||||
api_version: str | None = None,
|
||||
groups: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestLLMProvider:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestLLMProvider:
|
||||
print("Seeding LLM Providers...")
|
||||
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
@@ -49,7 +50,10 @@ class LLMProviderManager:
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
result_llm = TestLLMProvider(
|
||||
import json
|
||||
|
||||
print(json.dumps(response_data, indent=4))
|
||||
result_llm = DATestLLMProvider(
|
||||
id=response_data["id"],
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
@@ -73,11 +77,9 @@ class LLMProviderManager:
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
llm_provider: TestLLMProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
llm_provider: DATestLLMProvider,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
if not llm_provider.id:
|
||||
raise ValueError("LLM Provider ID is required to delete a provider")
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
|
||||
headers=user_performing_action.headers
|
||||
@@ -86,3 +88,43 @@ class LLMProviderManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[FullLLMProvider]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullLLMProvider(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
llm_provider: DATestLLMProvider,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"User group {llm_provider.id} found but should be deleted"
|
||||
)
|
||||
fetched_llm_groups = set(fetched_llm_provider.groups)
|
||||
llm_provider_groups = set(llm_provider.groups)
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.api_key == fetched_llm_provider.api_key
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"User group {llm_provider.id} not found")
|
||||
@@ -6,8 +6,8 @@ from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestPersona
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestPersona
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
@@ -27,8 +27,8 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestPersona:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
|
||||
@@ -59,7 +59,7 @@ class PersonaManager:
|
||||
response.raise_for_status()
|
||||
persona_data = response.json()
|
||||
|
||||
return TestPersona(
|
||||
return DATestPersona(
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
@@ -79,7 +79,7 @@ class PersonaManager:
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
persona: TestPersona,
|
||||
persona: DATestPersona,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
@@ -94,8 +94,8 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestPersona:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
persona_update_request = {
|
||||
"name": name or persona.name,
|
||||
"description": description or persona.description,
|
||||
@@ -127,7 +127,7 @@ class PersonaManager:
|
||||
response.raise_for_status()
|
||||
updated_persona_data = response.json()
|
||||
|
||||
return TestPersona(
|
||||
return DATestPersona(
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
@@ -151,7 +151,7 @@ class PersonaManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[PersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/persona",
|
||||
@@ -164,38 +164,46 @@ class PersonaManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
test_persona: TestPersona,
|
||||
user_performing_action: TestUser | None = None,
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_all(user_performing_action)
|
||||
for persona in all_personas:
|
||||
if persona.id == test_persona.id:
|
||||
for fetched_persona in all_personas:
|
||||
if fetched_persona.id == persona.id:
|
||||
return (
|
||||
persona.name == test_persona.name
|
||||
and persona.description == test_persona.description
|
||||
and persona.num_chunks == test_persona.num_chunks
|
||||
and persona.llm_relevance_filter
|
||||
== test_persona.llm_relevance_filter
|
||||
and persona.is_public == test_persona.is_public
|
||||
and persona.llm_filter_extraction
|
||||
== test_persona.llm_filter_extraction
|
||||
and persona.llm_model_provider_override
|
||||
== test_persona.llm_model_provider_override
|
||||
and persona.llm_model_version_override
|
||||
== test_persona.llm_model_version_override
|
||||
and set(persona.prompts) == set(test_persona.prompt_ids)
|
||||
and set(persona.document_sets) == set(test_persona.document_set_ids)
|
||||
and set(persona.tools) == set(test_persona.tool_ids)
|
||||
and set(user.email for user in persona.users)
|
||||
== set(test_persona.users)
|
||||
and set(persona.groups) == set(test_persona.groups)
|
||||
fetched_persona.name == persona.name
|
||||
and fetched_persona.description == persona.description
|
||||
and fetched_persona.num_chunks == persona.num_chunks
|
||||
and fetched_persona.llm_relevance_filter
|
||||
== persona.llm_relevance_filter
|
||||
and fetched_persona.is_public == persona.is_public
|
||||
and fetched_persona.llm_filter_extraction
|
||||
== persona.llm_filter_extraction
|
||||
and fetched_persona.llm_model_provider_override
|
||||
== persona.llm_model_provider_override
|
||||
and fetched_persona.llm_model_version_override
|
||||
== persona.llm_model_version_override
|
||||
and set([prompt.id for prompt in fetched_persona.prompts])
|
||||
== set(persona.prompt_ids)
|
||||
and set(
|
||||
[
|
||||
document_set.id
|
||||
for document_set in fetched_persona.document_sets
|
||||
]
|
||||
)
|
||||
== set(persona.document_set_ids)
|
||||
and set([tool.id for tool in fetched_persona.tools])
|
||||
== set(persona.tool_ids)
|
||||
and set(user.email for user in fetched_persona.users)
|
||||
== set(persona.users)
|
||||
and set(fetched_persona.groups) == set(persona.groups)
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
persona: TestPersona,
|
||||
user_performing_action: TestUser | None = None,
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
|
||||
@@ -10,14 +10,14 @@ from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
) -> TestUser:
|
||||
) -> DATestUser:
|
||||
if name is None:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
@@ -36,7 +36,7 @@ class UserManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
test_user = TestUser(
|
||||
test_user = DATestUser(
|
||||
id=response.json()["id"],
|
||||
email=email,
|
||||
password=password,
|
||||
@@ -49,7 +49,7 @@ class UserManager:
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def login_as_user(test_user: TestUser) -> str:
|
||||
def login_as_user(test_user: DATestUser) -> str:
|
||||
data = urlencode(
|
||||
{
|
||||
"username": test_user.email,
|
||||
@@ -74,7 +74,7 @@ class UserManager:
|
||||
return f"{result_cookie.name}={result_cookie.value}"
|
||||
|
||||
@staticmethod
|
||||
def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool:
|
||||
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me",
|
||||
headers=user_to_verify.headers,
|
||||
@@ -84,9 +84,9 @@ class UserManager:
|
||||
|
||||
@staticmethod
|
||||
def set_role(
|
||||
user_to_set: TestUser,
|
||||
user_to_set: DATestUser,
|
||||
target_role: UserRole,
|
||||
user_to_perform_action: TestUser | None = None,
|
||||
user_to_perform_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
if user_to_perform_action is None:
|
||||
user_to_perform_action = user_to_set
|
||||
@@ -98,7 +98,9 @@ class UserManager:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None:
|
||||
def verify(
|
||||
user: DATestUser, user_to_perform_action: DATestUser | None = None
|
||||
) -> None:
|
||||
if user_to_perform_action is None:
|
||||
user_to_perform_action = user
|
||||
response = requests.get(
|
||||
|
||||
@@ -7,8 +7,8 @@ from ee.danswer.server.user_group.models import UserGroup
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import TestUserGroup
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
|
||||
|
||||
class UserGroupManager:
|
||||
@@ -17,8 +17,8 @@ class UserGroupManager:
|
||||
name: str | None = None,
|
||||
user_ids: list[str] | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestUserGroup:
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestUserGroup:
|
||||
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
|
||||
|
||||
request = {
|
||||
@@ -34,7 +34,7 @@ class UserGroupManager:
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
test_user_group = TestUserGroup(
|
||||
test_user_group = DATestUserGroup(
|
||||
id=response.json()["id"],
|
||||
name=response.json()["name"],
|
||||
user_ids=[user["id"] for user in response.json()["users"]],
|
||||
@@ -44,11 +44,9 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
user_group: TestUserGroup,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
if not user_group.id:
|
||||
raise ValueError("User group has no ID")
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
json=user_group.model_dump(),
|
||||
@@ -59,14 +57,25 @@ class UserGroupManager:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def set_curator_status(
|
||||
test_user_group: TestUserGroup,
|
||||
user_to_set_as_curator: TestUser,
|
||||
is_curator: bool = True,
|
||||
user_performing_action: TestUser | None = None,
|
||||
def delete(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def set_curator_status(
|
||||
test_user_group: DATestUserGroup,
|
||||
user_to_set_as_curator: DATestUser,
|
||||
is_curator: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
if not user_to_set_as_curator.id:
|
||||
raise ValueError("User has no ID")
|
||||
set_curator_request = {
|
||||
"user_id": user_to_set_as_curator.id,
|
||||
"is_curator": is_curator,
|
||||
@@ -82,7 +91,7 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[UserGroup]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
@@ -95,9 +104,9 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
user_group: TestUserGroup,
|
||||
user_group: DATestUserGroup,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
for fetched_user_group in all_user_groups:
|
||||
@@ -120,8 +129,8 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_groups_to_check: list[TestUserGroup] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
user_groups_to_check: list[DATestUserGroup] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
@@ -130,7 +139,7 @@ class UserGroupManager:
|
||||
check_ids = {user_group.id for user_group in user_groups_to_check}
|
||||
user_group_ids = {user_group.id for user_group in user_groups}
|
||||
if not check_ids.issubset(user_group_ids):
|
||||
raise RuntimeError("Document set not found")
|
||||
raise RuntimeError("User group not found")
|
||||
user_groups = [
|
||||
user_group
|
||||
for user_group in user_groups
|
||||
@@ -146,3 +155,26 @@ class UserGroupManager:
|
||||
else:
|
||||
print("User groups were not synced yet, waiting...")
|
||||
time.sleep(2)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_groups_to_check: list[DATestUserGroup],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}
|
||||
while True:
|
||||
fetched_user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
fetched_user_group_ids = {
|
||||
user_group.id for user_group in fetched_user_groups
|
||||
}
|
||||
if not user_group_ids_to_check.intersection(fetched_user_group_ids):
|
||||
return
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"User groups deletion was not completed within the {MAX_DELAY} seconds"
|
||||
)
|
||||
else:
|
||||
print("Some user groups are still being deleted, waiting...")
|
||||
time.sleep(2)
|
||||
|
||||
@@ -20,6 +20,9 @@ from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.main import setup_postgres
|
||||
from danswer.main import setup_vespa
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _run_migrations(
|
||||
@@ -165,8 +168,8 @@ def reset_vespa() -> None:
|
||||
|
||||
def reset_all() -> None:
|
||||
"""Reset both Postgres and Vespa."""
|
||||
print("Resetting Postgres...")
|
||||
logger.info("Resetting Postgres...")
|
||||
reset_postgres()
|
||||
print("Resetting Vespa...")
|
||||
logger.info("Resetting Vespa...")
|
||||
reset_vespa()
|
||||
print("Finished resetting all.")
|
||||
logger.info("Finished resetting all.")
|
||||
|
||||
@@ -20,7 +20,7 @@ This means the flow is:
|
||||
"""
|
||||
|
||||
|
||||
class TestAPIKey(BaseModel):
|
||||
class DATestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
api_key: str | None = None # only present on initial creation
|
||||
@@ -31,14 +31,14 @@ class TestAPIKey(BaseModel):
|
||||
headers: dict
|
||||
|
||||
|
||||
class TestUser(BaseModel):
|
||||
class DATestUser(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
password: str
|
||||
headers: dict
|
||||
|
||||
|
||||
class TestCredential(BaseModel):
|
||||
class DATestCredential(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
credential_json: dict[str, Any]
|
||||
@@ -48,7 +48,7 @@ class TestCredential(BaseModel):
|
||||
groups: list[int]
|
||||
|
||||
|
||||
class TestConnector(BaseModel):
|
||||
class DATestConnector(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
source: DocumentSource
|
||||
@@ -63,7 +63,7 @@ class SimpleTestDocument(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class TestCCPair(BaseModel):
|
||||
class DATestCCPair(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
connector_id: int
|
||||
@@ -73,26 +73,26 @@ class TestCCPair(BaseModel):
|
||||
documents: list[SimpleTestDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TestUserGroup(BaseModel):
|
||||
class DATestUserGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
user_ids: list[str]
|
||||
cc_pair_ids: list[int]
|
||||
|
||||
|
||||
class TestLLMProvider(BaseModel):
|
||||
class DATestLLMProvider(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
groups: list[TestUserGroup]
|
||||
groups: list[int]
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
|
||||
|
||||
class TestDocumentSet(BaseModel):
|
||||
class DATestDocumentSet(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
@@ -103,7 +103,7 @@ class TestDocumentSet(BaseModel):
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TestPersona(BaseModel):
|
||||
class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
@@ -122,13 +122,13 @@ class TestPersona(BaseModel):
|
||||
|
||||
|
||||
#
|
||||
class TestChatSession(BaseModel):
|
||||
class DATestChatSession(BaseModel):
|
||||
id: int
|
||||
persona_id: int
|
||||
description: str
|
||||
|
||||
|
||||
class TestChatMessage(BaseModel):
|
||||
class DATestChatMessage(BaseModel):
|
||||
id: str | None = None
|
||||
chat_session_id: int
|
||||
parent_message_id: str | None
|
||||
|
||||
@@ -3,7 +3,7 @@ import requests
|
||||
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
||||
|
||||
|
||||
class TestVespaClient:
|
||||
class vespa_fixture:
|
||||
def __init__(self, index_name: str):
|
||||
self.index_name = index_name
|
||||
self.vespa_document_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def load_env_vars(env_file: str = ".env") -> None:
|
||||
@@ -36,9 +36,9 @@ def db_session() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vespa_client(db_session: Session) -> TestVespaClient:
|
||||
def vespa_client(db_session: Session) -> vespa_fixture:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return TestVespaClient(index_name=search_settings.index_name)
|
||||
return vespa_fixture(index_name=search_settings.index_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -22,17 +22,17 @@ from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import TestUserGroup
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
|
||||
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
# create api key
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -76,11 +76,11 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
|
||||
print("Document sets created and synced")
|
||||
|
||||
# create user groups
|
||||
user_group_1: TestUserGroup = UserGroupManager.create(
|
||||
user_group_1: DATestUserGroup = UserGroupManager.create(
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
user_group_2: TestUserGroup = UserGroupManager.create(
|
||||
user_group_2: DATestUserGroup = UserGroupManager.create(
|
||||
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -174,15 +174,15 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
|
||||
|
||||
|
||||
def test_connector_deletion_for_overlapping_connectors(
|
||||
reset: None, vespa_client: TestVespaClient
|
||||
reset: None, vespa_client: vespa_fixture
|
||||
) -> None:
|
||||
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
|
||||
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
|
||||
"""
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
# create api key
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -251,7 +251,7 @@ def test_connector_deletion_for_overlapping_connectors(
|
||||
)
|
||||
|
||||
# create a user group and attach it to connector 1
|
||||
user_group_1: TestUserGroup = UserGroupManager.create(
|
||||
user_group_1: DATestUserGroup = UserGroupManager.create(
|
||||
name="Test User Group 1",
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
@@ -265,7 +265,7 @@ def test_connector_deletion_for_overlapping_connectors(
|
||||
print("User group 1 created and synced")
|
||||
|
||||
# create a user group and attach it to connector 2
|
||||
user_group_2: TestUserGroup = UserGroupManager.create(
|
||||
user_group_2: DATestUserGroup = UserGroupManager.create(
|
||||
name="Test User Group 2",
|
||||
cc_pair_ids=[cc_pair_2.id],
|
||||
user_performing_action=admin_user,
|
||||
|
||||
@@ -2,25 +2,25 @@ import requests
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.llm import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestCCPair
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connector
|
||||
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
@@ -3,25 +3,25 @@ import requests
|
||||
from danswer.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.llm import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestCCPair
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_send_message_simple_with_history(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connectors
|
||||
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
@@ -64,13 +64,13 @@ def test_send_message_simple_with_history(reset: None) -> None:
|
||||
|
||||
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connector
|
||||
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
@@ -5,19 +5,19 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def test_multiple_document_sets_syncing_same_connnector(
|
||||
reset: None, vespa_client: TestVespaClient
|
||||
reset: None, vespa_client: vespa_fixture
|
||||
) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
# create api key
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -66,12 +66,12 @@ def test_multiple_document_sets_syncing_same_connnector(
|
||||
)
|
||||
|
||||
|
||||
def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None:
|
||||
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
# create api key
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,17 +10,17 @@ from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.user import TestUser
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def test_cc_pair_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Creating a curator
|
||||
curator: TestUser = UserManager.create(name="curator")
|
||||
curator: DATestUser = UserManager.create(name="curator")
|
||||
|
||||
# Creating a user group
|
||||
user_group_1 = UserGroupManager.create(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user