Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
14f57d6475 remove endpoint 2024-10-31 12:07:47 -07:00
823 changed files with 27323 additions and 45698 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,18 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -96,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
severity: 'CRITICAL,HIGH'

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,76 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -13,10 +13,7 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -198,13 +195,9 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
/app/tests/integration/tests
continue-on-error: true
id: run_tests
@@ -217,18 +210,17 @@ jobs:
echo "All integration tests passed successfully."
fi
# save before stopping the containers so the logs can be captured
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,225 +0,0 @@
name: Run Chromatic Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -1,72 +0,0 @@
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -0,0 +1,68 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -18,14 +18,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -15,7 +15,7 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
model-check:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

1
.gitignore vendored
View File

@@ -7,4 +7,3 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/

View File

@@ -203,7 +203,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion",
],
"presentation": {
"group": "2",
@@ -232,7 +232,7 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
"connector_pruning",
],
"presentation": {
"group": "2",

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -12,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -73,11 +83,11 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Literal
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
@@ -8,7 +8,6 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
"""
Determines whether a database object should be included in migrations.

View File

@@ -1,59 +0,0 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -1,68 +0,0 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_new =
CASE
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
)
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"chosen_assistants_old",
postgresql.JSONB(),
nullable=False,
server_default="[-2, -1, 0]",
),
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_old =
CASE
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
)

View File

@@ -1,30 +0,0 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_external_group_sync",
sa.DateTime(timezone=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_time_external_group_sync")

View File

@@ -1,50 +0,0 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@@ -1,45 +0,0 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -1,280 +0,0 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

View File

@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -1,45 +0,0 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,35 +0,0 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -7,7 +7,6 @@ Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
@@ -16,6 +15,7 @@ from danswer.db.models import (
Credential,
IndexAttempt,
)
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
@@ -30,11 +30,13 @@ def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
connectors_to_delete = (
session.query(Connector)
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
.all()
)
connector_ids = [row[0] for row in result]
connector_ids = [connector.id for connector in connectors_to_delete]
if connector_ids:
cc_pairs_to_delete = (

View File

@@ -1,30 +0,0 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"creator_id",
sa.UUID(as_uuid=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "creator_id")

View File

@@ -1,36 +0,0 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -1,27 +0,0 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -1,30 +0,0 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -288,15 +288,6 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -1,48 +0,0 @@
"""remove description from starter messages
Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem - 'description')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem || '{"description": ""}')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -1,29 +0,0 @@
"""add recent assistants
Revision ID: c0fd6e4da83a
Revises: b72ed7a5db0e
Create Date: 2024-11-03 17:28:54.916618
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c0fd6e4da83a"
down_revision = "b72ed7a5db0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)
def downgrade() -> None:
op.drop_column("user", "recent_assistants")

View File

@@ -23,56 +23,6 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -1,42 +0,0 @@
"""extended_role_for_non_web
Revision ID: dfbe9e93d3c7
Revises: 9cf5c00f72fe
Create Date: 2024-11-16 07:54:18.727906
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7"
down_revision = "9cf5c00f72fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE "user"
SET role = 'EXT_PERM_USER'
WHERE has_web_login = false
"""
)
op.drop_column("user", "has_web_login")
def downgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
op.execute(
"""
UPDATE "user"
SET has_web_login = false,
role = 'BASIC'
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
"""
)

View File

@@ -1,6 +1,5 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -16,46 +16,6 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str
def to_dict(self) -> dict:
return {
"external_access": {
"external_user_emails": list(self.external_access.external_user_emails),
"external_user_group_ids": list(
self.external_access.external_user_group_ids
),
"is_public": self.external_access.is_public,
},
"doc_id": self.doc_id,
}
@classmethod
def from_dict(cls, data: dict) -> "DocExternalAccess":
external_access = ExternalAccess(
external_user_emails=set(
data["external_access"].get("external_user_emails", [])
),
external_user_group_ids=set(
data["external_access"].get("external_user_group_ids", [])
),
is_public=data["external_access"]["is_public"],
)
return cls(
external_access=external_access,
doc_id=data["doc_id"],
)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin

View File

@@ -1,100 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.answer_query.nodes.answer_check import answer_check
from danswer.agent_search.answer_query.nodes.answer_generation import answer_generation
from danswer.agent_search.answer_query.nodes.format_answer import format_answer
from danswer.agent_search.answer_query.states import AnswerQueryInput
from danswer.agent_search.answer_query.states import AnswerQueryOutput
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def answer_query_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=AnswerQueryState,
input=AnswerQueryInput,
output=AnswerQueryOutput,
)
### Add nodes ###
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="expanded_retrieval_for_initial_decomp",
action=expanded_retrieval,
)
graph.add_node(
node="answer_check",
action=answer_check,
)
graph.add_node(
node="answer_generation",
action=answer_generation,
)
graph.add_node(
node="format_answer",
action=format_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expanded_retrieval_for_initial_decomp",
)
graph.add_edge(
start_key="expanded_retrieval_for_initial_decomp",
end_key="answer_generation",
)
graph.add_edge(
start_key="answer_generation",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
from danswer.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="Who made Excel and what other products did they make?",
)
with get_session_context_manager() as db_session:
inputs = AnswerQueryInput(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
query_to_answer="Who made Excel?",
)
output = compiled_graph.invoke(
input=inputs,
# debug=True,
# subgraphs=True,
)
print(output)
# for namespace, chunk in compiled_graph.stream(
# input=inputs,
# # debug=True,
# subgraphs=True,
# ):
# print(namespace)
# print(chunk)

View File

@@ -1,30 +0,0 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import QACheckOutput
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
def answer_check(state: AnswerQueryState) -> QACheckOutput:
msg = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(
question=state["search_request"].query,
base_answer=state["answer"],
)
)
]
fast_llm = state["fast_llm"]
response = list(
fast_llm.stream(
prompt=msg,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
return QACheckOutput(
answer_quality=response_str,
)

View File

@@ -1,32 +0,0 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import QAGenerationOutput
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
query = state["query_to_answer"]
docs = state["reordered_documents"]
print(f"Number of verified retrieval docs: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs))
)
]
fast_llm = state["fast_llm"]
response = list(
fast_llm.stream(
prompt=msg,
)
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
return QAGenerationOutput(
answer=answer_str,
)

View File

@@ -1,16 +0,0 @@
from danswer.agent_search.answer_query.states import AnswerQueryOutput
from danswer.agent_search.answer_query.states import AnswerQueryState
from danswer.agent_search.answer_query.states import SearchAnswerResults
def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
return AnswerQueryOutput(
decomp_answer_results=[
SearchAnswerResults(
query=state["query_to_answer"],
quality=state["answer_quality"],
answer=state["answer"],
documents=state["reordered_documents"],
)
],
)

View File

@@ -1,45 +0,0 @@
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from danswer.agent_search.core_state import PrimaryState
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
from danswer.context.search.models import InferenceSection
class SearchAnswerResults(BaseModel):
query: str
answer: str
quality: str
documents: Annotated[list[InferenceSection], dedup_inference_sections]
class QACheckOutput(TypedDict, total=False):
answer_quality: str
class QAGenerationOutput(TypedDict, total=False):
answer: str
class ExpandedRetrievalOutput(TypedDict):
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class AnswerQueryState(
PrimaryState,
QACheckOutput,
QAGenerationOutput,
ExpandedRetrievalOutput,
total=True,
):
query_to_answer: str
class AnswerQueryInput(PrimaryState, total=True):
query_to_answer: str
class AnswerQueryOutput(TypedDict):
decomp_answer_results: list[SearchAnswerResults]

View File

@@ -1,15 +0,0 @@
from typing import TypedDict
from sqlalchemy.orm import Session
from danswer.context.search.models import SearchRequest
from danswer.llm.interfaces import LLM
class PrimaryState(TypedDict, total=False):
search_request: SearchRequest
primary_llm: LLM
fast_llm: LLM
# a single session for the entire agent search
# is fine if we are only reading
db_session: Session

View File

@@ -1,114 +0,0 @@
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
# aggregate sub questions and answers
def deep_answer_generation(state: MainState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---DEEP GENERATE---")
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
deep_answer_context = state["core_answer_dynamic_context"]
print(f"Number of verified retrieval docs - deep: {len(docs)}")
combined_context = normalize_whitespace(
COMBINED_CONTEXT.format(
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
)
)
msg = [
HumanMessage(
content=MODIFIED_RAG_PROMPT.format(
question=question, combined_context=combined_context
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"deep_answer": response.content,
}
def final_stuff(state: MainState) -> dict[str, Any]:
"""
Invokes the agent model to generate a response based on the current state. Given
the question, it will decide to retrieve using the retriever tool, or simply end.
Args:
state (messages): The current state
Returns:
dict: The updated state with the agent response appended to messages
"""
print("---FINAL---")
messages = state["log_messages"]
time_ordered_messages = [x.pretty_repr() for x in messages]
time_ordered_messages.sort()
print("Message Log:")
print("\n".join(time_ordered_messages))
initial_sub_qas = state["initial_sub_qas"]
initial_sub_qa_list = []
for initial_sub_qa in initial_sub_qas:
if initial_sub_qa["sub_answer_check"] == "yes":
initial_sub_qa_list.append(
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
)
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
base_answer = state["base_answer"]
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
print("--------------------------------")
if not state.get("deep_answer"):
print("No Deep Answer was required")
return {}
deep_answer = state["deep_answer"]
sub_qas = state["sub_qas"]
sub_qa_list = []
for sub_qa in sub_qas:
if sub_qa["sub_answer_check"] == "yes":
sub_qa_list.append(
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
)
sub_qa_context = "\n".join(sub_qa_list)
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Final Deep Answer:\n{deep_answer}")
print("--------------------------------")
print("Sub Questions and Answers:")
print(sub_qa_context)
return {}

View File

@@ -1,78 +0,0 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def decompose(state: MainState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
question = state["original_question"]
base_answer = state["base_answer"]
# get the entity term extraction dict and properly format it
entity_term_extraction_dict = state["retrieved_entities_relationships"][
"retrieved_entities_relationships"
]
entity_term_extraction_str = format_entity_term_extraction(
entity_term_extraction_dict
)
initial_question_answers = state["initial_sub_qas"]
addressed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "yes"
]
failed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "no"
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
question=question,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
parsed_response = json.loads(cleaned_response)
sub_questions_dict = {}
for sub_question_nr, sub_question_dict in enumerate(
parsed_response["sub_questions"]
):
sub_question_dict["answered"] = False
sub_question_dict["verified"] = False
sub_questions_dict[sub_question_nr] = sub_question_dict
return {
"decomposed_sub_questions_dict": sub_questions_dict,
"log_messages": generate_log_message(
message="deep - decompose",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -1,40 +0,0 @@
import json
import re
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
def entity_term_extraction(state: MainState) -> dict[str, Any]:
"""Extract entities and terms from the question and context"""
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
doc_context = format_docs(docs)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
fast_llm = state["fast_llm"]
# Grader
llm_response_list = list(
fast_llm.stream(
prompt=msg,
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
parsed_response = json.loads(cleaned_response)
return {
"retrieved_entities_relationships": parsed_response,
}

View File

@@ -1,30 +0,0 @@
from typing import Any
from danswer.agent_search.main.states import MainState
# aggregate sub questions and answers
def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]:
sub_qas = state["sub_qas"]
dynamic_context_list = [
"Below you will find useful information to answer the original question:"
]
checked_sub_qas = []
for core_answer_sub_qa in sub_qas:
question = core_answer_sub_qa["sub_question"]
answer = core_answer_sub_qa["sub_answer"]
verified = core_answer_sub_qa["sub_answer_check"]
if verified == "yes":
dynamic_context_list.append(
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
)
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
dynamic_context = "\n".join(dynamic_context_list)
return {
"core_answer_dynamic_context": dynamic_context,
"checked_sub_qas": checked_sub_qas,
}

View File

@@ -1,19 +0,0 @@
from typing import Any
from danswer.agent_search.main.states import MainState
def sub_qa_manager(state: MainState) -> dict[str, Any]:
""" """
sub_questions_dict = state["decomposed_sub_questions_dict"]
sub_questions = {}
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
return {
"sub_questions": sub_questions,
"num_new_question_iterations": 0,
}

View File

@@ -1,44 +0,0 @@
from collections.abc import Hashable
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langgraph.types import Send
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.llm.interfaces import LLM
def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
print(f"parallel_retrieval_edge state: {state.keys()}")
# This should be better...
question = state.get("query_to_answer") or state["search_request"].query
llm: LLM = state["fast_llm"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
llm_response_list = list(
llm.stream(
prompt=msg,
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
print(f"llm_response: {llm_response}")
rewritten_queries = llm_response.split("\n")
print(f"rewritten_queries: {rewritten_queries}")
return [
Send(
"doc_retrieval",
RetrieveInput(query_to_retrieve=query, **state),
)
for query in rewritten_queries
]

View File

@@ -1,88 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
doc_verification,
)
from danswer.agent_search.expanded_retrieval.nodes.verification_kickoff import (
verification_kickoff,
)
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
graph.add_node(
node="doc_retrieval",
action=doc_retrieval,
)
graph.add_node(
node="verification_kickoff",
action=verification_kickoff,
)
graph.add_node(
node="doc_verification",
action=doc_verification,
)
graph.add_node(
node="doc_reranking",
action=doc_reranking,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key=END,
)
return graph
if __name__ == "__main__":
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
from danswer.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="Who made Excel and what other products did they make?",
)
with get_session_context_manager() as db_session:
inputs = ExpandedRetrievalInput(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
query_to_answer="Who made Excel?",
)
for thing in compiled_graph.stream(inputs, debug=True):
print(thing)

View File

@@ -1,11 +0,0 @@
from danswer.agent_search.expanded_retrieval.states import DocRerankingOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput:
print(f"doc_reranking state: {state.keys()}")
verified_documents = state["verified_documents"]
reranked_documents = verified_documents
return DocRerankingOutput(reranked_documents=reranked_documents)

View File

@@ -1,47 +0,0 @@
from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
class RetrieveInput(ExpandedRetrievalState):
query_to_retrieve: str
def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput:
# def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print(f"doc_retrieval state: {state.keys()}")
state["query_to_retrieve"]
documents: list[InferenceSection] = []
llm = state["primary_llm"]
fast_llm = state["fast_llm"]
# db_session = state["db_session"]
query_to_retrieve = state["search_request"].query
with get_session_context_manager() as db_session1:
documents = SearchPipeline(
search_request=SearchRequest(
query=query_to_retrieve,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session1,
).reranked_sections
print(f"retrieved documents: {len(documents)}")
return DocRetrievalOutput(
retrieved_documents=documents,
)

View File

@@ -1,60 +0,0 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.expanded_retrieval.states import DocVerificationOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.context.search.models import InferenceSection
class DocVerificationInput(ExpandedRetrievalState, total=True):
doc_to_verify: InferenceSection
def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print(f"doc_verification state: {state.keys()}")
original_query = state["search_request"].query
doc_to_verify = state["doc_to_verify"]
document_content = doc_to_verify.combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=original_query, document_content=document_content
)
)
]
fast_llm = state["fast_llm"]
response = list(
fast_llm.stream(
prompt=msg,
)
)
response_string = merge_message_runs(response, chunk_separator="")[0].content
# Convert string response to proper dictionary format
decision_dict = {"decision": response_string.lower()}
formatted_response = BinaryDecision.model_validate(decision_dict)
print(f"Verdict: {formatted_response.decision}")
verified_documents = []
if formatted_response.decision == "yes":
verified_documents.append(doc_to_verify)
return DocVerificationOutput(
verified_documents=verified_documents,
)

View File

@@ -1,27 +0,0 @@
from typing import Literal
from langgraph.types import Command
from langgraph.types import Send
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
DocVerificationInput,
)
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
def verification_kickoff(
state: ExpandedRetrievalState,
) -> Command[Literal["doc_verification"]]:
print(f"verification_kickoff state: {state.keys()}")
documents = state["retrieved_documents"]
return Command(
update={},
goto=[
Send(
node="doc_verification",
arg=DocVerificationInput(doc_to_verify=doc, **state),
)
for doc in documents
],
)

View File

@@ -1,36 +0,0 @@
from typing import Annotated
from typing import TypedDict
from danswer.agent_search.core_state import PrimaryState
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
from danswer.context.search.models import InferenceSection
class DocRetrievalOutput(TypedDict, total=False):
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocVerificationOutput(TypedDict, total=False):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class DocRerankingOutput(TypedDict, total=False):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class ExpandedRetrievalState(
PrimaryState,
DocRetrievalOutput,
DocVerificationOutput,
DocRerankingOutput,
total=True,
):
query_to_answer: str
class ExpandedRetrievalInput(PrimaryState, total=True):
query_to_answer: str
class ExpandedRetrievalOutput(TypedDict):
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]

View File

@@ -1,61 +0,0 @@
from collections.abc import Hashable
from langgraph.types import Send
from danswer.agent_search.answer_query.states import AnswerQueryInput
from danswer.agent_search.main.states import MainState
def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
return [
Send(
"answer_query",
AnswerQueryInput(
**state,
query_to_answer=query,
),
)
for query in state["initial_decomp_queries"]
]
# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
# # Routes re-written queries to the (parallel) retrieval steps
# # Notice the 'Send()' API that takes care of the parallelization
# return [
# Send(
# "sub_answers_graph",
# ResearchQAState(
# sub_question=sub_question["sub_question_str"],
# sub_question_nr=sub_question["sub_question_nr"],
# graph_start_time=state["graph_start_time"],
# primary_llm=state["primary_llm"],
# fast_llm=state["fast_llm"],
# ),
# )
# for sub_question in state["sub_questions"]
# ]
# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
# print("---GO TO DEEP ANSWER OR END---")
# base_answer = state["base_answer"]
# question = state["original_question"]
# BASE_CHECK_MESSAGE = [
# HumanMessage(
# content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
# )
# ]
# model = state["fast_llm"]
# response = model.invoke(BASE_CHECK_MESSAGE)
# print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
# if response.pretty_repr() == "no":
# return "decompose"
# else:
# return "end"

View File

@@ -1,98 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.answer_query.graph_builder import answer_query_graph_builder
from danswer.agent_search.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from danswer.agent_search.main.edges import parallelize_decompozed_answer_queries
from danswer.agent_search.main.nodes.base_decomp import main_decomp_base
from danswer.agent_search.main.nodes.generate_initial_answer import (
generate_initial_answer,
)
from danswer.agent_search.main.states import MainInput
from danswer.agent_search.main.states import MainState
def main_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
### Add nodes ###
graph.add_node(
node="base_decomp",
action=main_decomp_base,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query",
action=answer_query_subgraph,
)
expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="expanded_retrieval",
action=expanded_retrieval_subgraph,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expanded_retrieval",
)
graph.add_edge(
start_key=START,
end_key="base_decomp",
)
graph.add_conditional_edges(
source="base_decomp",
path=parallelize_decompozed_answer_queries,
path_map=["answer_query"],
)
graph.add_edge(
start_key=["answer_query", "expanded_retrieval"],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
from danswer.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="If i am familiar with the function that I need, how can I type it into a cell?",
)
with get_session_context_manager() as db_session:
inputs = MainInput(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
)
for thing in compiled_graph.stream(
input=inputs,
# stream_mode="debug",
# debug=True,
subgraphs=True,
):
# print(thing)
print()
print()

View File

@@ -1,31 +0,0 @@
from langchain_core.messages import HumanMessage
from danswer.agent_search.main.states import BaseDecompOutput
from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
def main_decomp_base(state: MainState) -> BaseDecompOutput:
question = state["search_request"].query
msg = [
HumanMessage(
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
)
]
# Get the rewritten queries in a defined format
model = state["fast_llm"]
response = model.invoke(msg)
content = response.pretty_repr()
list_of_subquestions = clean_and_parse_list_string(content)
decomp_list: list[str] = [
sub_question["sub_question"].strip() for sub_question in list_of_subquestions
]
return BaseDecompOutput(
initial_decomp_queries=decomp_list,
)

View File

@@ -1,53 +0,0 @@
from langchain_core.messages import HumanMessage
from danswer.agent_search.main.states import InitialAnswerOutput
from danswer.agent_search.main.states import MainState
from danswer.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
print("---GENERATE INITIAL---")
question = state["search_request"].query
docs = state["documents"]
decomp_answer_results = state["decomp_answer_results"]
good_qa_list: list[str] = []
_SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
for decomp_answer_result in decomp_answer_results:
if (
decomp_answer_result.quality.lower() == "yes"
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != "I don't know"
):
good_qa_list.append(
_SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.query,
sub_answer=decomp_answer_result.answer,
)
)
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
msg = [
HumanMessage(
content=INITIAL_RAG_PROMPT.format(
question=question,
context=format_docs(docs),
answered_sub_questions=sub_question_answer_str,
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
answer = response.pretty_repr()
print(answer)
return InitialAnswerOutput(initial_answer=answer)

View File

@@ -1,37 +0,0 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from danswer.agent_search.answer_query.states import SearchAnswerResults
from danswer.agent_search.core_state import PrimaryState
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
from danswer.context.search.models import InferenceSection
class BaseDecompOutput(TypedDict, total=False):
initial_decomp_queries: list[str]
class InitialAnswerOutput(TypedDict, total=False):
initial_answer: str
class MainState(
PrimaryState,
BaseDecompOutput,
InitialAnswerOutput,
total=True,
):
documents: Annotated[list[InferenceSection], dedup_inference_sections]
decomp_answer_results: Annotated[list[SearchAnswerResults], add]
class MainInput(PrimaryState, total=True):
pass
class MainOutput(TypedDict):
"""
This is not used because defining the output only matters for filtering the output of
a .invoke() call but we are streaming so we just yield the entire state.
"""

View File

@@ -1,27 +0,0 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool
def run_graph(
query: str,
llm: LLM,
tools: list[Tool],
) -> AnswerStream:
graph = build_core_graph()
inputs = {
"original_query": query,
"messages": [],
"tools": tools,
"llm": llm,
}
compiled_graph = graph.compile()
output = compiled_graph.invoke(input=inputs)
yield from output
if __name__ == "__main__":
pass
# run_graph("What is the capital of France?", llm, [])

View File

@@ -1,12 +0,0 @@
from typing import Literal
from pydantic import BaseModel
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]

View File

@@ -1,9 +0,0 @@
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.prune_and_merge import _merge_sections
def dedup_inference_sections(
list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
deduped = _merge_sections(list1 + list2)
return deduped

View File

@@ -1,427 +0,0 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Use three sentences maximum and keep the answer concise.
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of json dictionaries with the following format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.
If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!
Again, only use the provided informationand do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!
Try to keep your answer concise.
And here is the question and the provided information:
\n
\nQuestion:\n {question}
\nAnswered Sub-questions:\n {answered_sub_questions}
\nContext:\n {context} \n\n
\n\n
Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"name": <assign a name for the relationship>,
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"entities": [<related entity name 1>, <related entity name 2>]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""

View File

@@ -1,101 +0,0 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from danswer.context.search.models import InferenceSection
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
return "\n\n".join(doc.combined_content for doc in docs)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove any prefixes/labels before the actual JSON content
json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Try parsing with json.loads first, fall back to ast.literal_eval
try:
return json.loads(cleaned_string)
except json.JSONDecodeError:
try:
return ast.literal_eval(cleaned_string)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
entities = entity_term_extraction_dict["entities"]
terms = entity_term_extraction_dict["terms"]
relationships = entity_term_extraction_dict["relationships"]
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"

View File

@@ -2,8 +2,8 @@ from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:

View File

@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -13,24 +13,12 @@ class UserRole(str, Enum):
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
- Limited can access a limited set of basic api endpoints
- Slack are users that have used danswer via slack but dont have a web login
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,
UserRole.EXT_PERM_USER,
]
class UserStatus(str, Enum):
@@ -45,8 +33,10 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -48,10 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -74,28 +75,28 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -189,6 +190,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -217,26 +232,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
self,
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -244,9 +258,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
@@ -259,15 +271,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
@@ -275,13 +292,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
self,
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
@@ -292,24 +307,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -318,11 +329,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -362,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
@@ -374,11 +383,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
@@ -391,15 +396,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login():
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -447,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -474,8 +474,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.role.is_web_login():
raise BasicAuthenticationError(
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -507,30 +510,19 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> TenantAwareJWTStrategy:
def get_jwt_strategy() -> JWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
@@ -605,7 +597,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
async_db_session: AsyncSession,
db_session: Session,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -614,21 +606,13 @@ async def optional_user_(
async def optional_user(
request: Request,
async_db_session: AsyncSession = Depends(get_async_session),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
return await versioned_fetch_user(request, user, db_session)
async def double_check_user(
@@ -640,12 +624,14 @@ async def double_check_user(
return None
if user is None:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
@@ -654,7 +640,8 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
@@ -667,24 +654,10 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
return await double_check_user(user)
async def current_curator_or_admin_user(
@@ -694,13 +667,15 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
@@ -712,7 +687,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
)
@@ -740,6 +716,8 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -789,22 +767,15 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request,
scopes: List[str] = Query(None),
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -863,11 +834,8 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -909,25 +877,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,7 +3,6 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -11,26 +10,19 @@ from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
@@ -116,43 +108,29 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(tenant_id, int(document_set_id))
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(tenant_id, int(usergroup_id))
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@@ -162,154 +140,27 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness probe starting.")
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@@ -317,7 +168,57 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("Redis: Readiness check succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("All tenant primary workers are ready. Continuing...")
return
@@ -329,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock: RedisLock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
def on_setup_logging(

View File

@@ -3,162 +3,28 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
# celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
@@ -169,4 +35,68 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.conf.beat_scheduler = DynamicTenantScheduler
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -91,7 +84,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -6,7 +6,6 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -14,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,15 +58,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -82,11 +74,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -91,7 +85,5 @@ celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

View File

@@ -1,6 +1,5 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -11,33 +10,25 @@ from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from redis.lock import Lock as RedisLock
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -84,98 +75,95 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
RedisDocumentSet.reset_all(r)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
RedisUserGroup.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorDelete.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorPrune.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorIndex.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
RedisConnectorStop.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
RedisConnectorPermissionSync.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorExternalGroupSync.reset_all(r)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
@@ -228,36 +216,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
lock: RedisLock = worker.primary_worker_lock
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
r = get_redis_client(tenant_id=None)
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception:
task_logger.exception("Periodic task failed.")
@@ -276,8 +280,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",

View File

@@ -1,10 +1,568 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:

View File

@@ -4,6 +4,8 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -16,8 +18,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -40,14 +41,14 @@ def _get_deletion_status(
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)
@@ -78,10 +79,10 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
@@ -111,15 +112,10 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

View File

@@ -1,61 +0,0 @@
from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -1,25 +1,30 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_pool import get_redis_client
@@ -29,7 +34,7 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -37,7 +42,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -57,19 +62,19 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, lock_beat, tenant_id
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -86,7 +91,8 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
@@ -100,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
# don't generate sync tasks if tasks are still pending
if redis_connector.delete.fenced:
if r.exists(rcd.fence_key):
return None
# we need to load the state of the object inside the fence
@@ -117,55 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletePayload(
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
redis_connector.delete.set_fence(fence_payload)
r.set(rcd.fence_key, fence_value.model_dump_json())
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.delete.generate_tasks(
app, db_session, lock_beat
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
@@ -180,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
return tasks_generated

View File

@@ -1,345 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip doc permissions sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
return True
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
if redis_connector.permissions.fenced:
return None
if redis_connector.delete.fenced:
return None
if redis_connector.prune.fenced:
return None
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)
raise e
finally:
if lock.owned():
lock.release()
@shared_task(
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
return False

View File

@@ -1,298 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
logger = setup_logger()
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
return False
# If the last sync is None, it has never been run so we run the sync
last_ext_group_sync = cc_pair.last_time_external_group_sync
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
return None
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@@ -2,171 +2,99 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_lock: redis.lock.Lock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: RedisLock = redis_lock
self.redis_lock: redis.lock.Lock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
name="check_for_indexing",
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -176,49 +104,44 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
if old_search_settings:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
# only warm up if search settings were changed
warm_up_bi_encoder(
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
)
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -230,80 +153,31 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
reindex,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -311,14 +185,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if lock_beat.owned():
lock_beat.release()
return tasks_created
@@ -327,7 +195,6 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -392,11 +259,6 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -429,11 +291,10 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -443,15 +304,15 @@ def try_creating_indexing_task(
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -459,17 +320,19 @@ def try_creating_indexing_task(
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexPayload(
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_fence(payload)
r.set(rci.fence_key, fence_value.model_dump_json())
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
@@ -482,12 +345,8 @@ def try_creating_indexing_task(
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -502,20 +361,17 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
r.delete(rci.fence_key)
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -524,34 +380,19 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
)
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -561,138 +402,32 @@ def connector_indexing_proxy_task(
)
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
finally:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
job.cancel()
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
)
job.release()
break
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -711,102 +446,78 @@ def connector_indexing_task(
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={rcd.fence_key}"
)
if redis_connector.stop.fenced:
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
if not redis_connector_index.fenced: # The fence must exist
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
)
logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
task_logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -815,7 +526,6 @@ def connector_indexing_task(
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -835,52 +545,43 @@ def connector_indexing_task(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
# define a callback class
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return n_final_progress

View File

@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,

View File

@@ -8,12 +8,14 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -21,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
@@ -32,7 +33,6 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import pruning_ctx
from danswer.utils.logger import setup_logger
@@ -40,44 +40,8 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
Next pruning time is calculated as a delta from the last successful prune, or the
last successful indexing if pruning has never succeeded.
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
if pruning fails (which is what it does now). A backoff could be reasonable.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -107,7 +71,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue
if not _is_pruning_due(cc_pair):
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
@@ -128,6 +92,47 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -142,11 +147,8 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
@@ -163,16 +165,15 @@ def try_creating_prune_generator_task(
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if redis_connector.prune.fenced:
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -180,13 +181,13 @@ def try_creating_prune_generator_task(
return None
# add a long running generator task to the queue
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
@@ -199,7 +200,7 @@ def try_creating_prune_generator_task(
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -211,7 +212,7 @@ def try_creating_prune_generator_task(
@shared_task(
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -234,18 +235,13 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcp = RedisConnectorPruning(cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
@@ -269,11 +265,6 @@ def connector_pruning_generator_task(
)
return
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -282,13 +273,11 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
@@ -310,34 +299,36 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.reset()
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -9,7 +9,6 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
@shared_task(
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
name="document_by_cc_pair_cleanup_task",
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
@@ -60,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
task_logger.info(f"tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -142,9 +141,7 @@ def document_by_cc_pair_cleanup_task(
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
@@ -174,21 +171,11 @@ def document_by_cc_pair_cleanup_task(
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
with get_session_with_tenant(tenant_id):
mark_document_as_modified(document_id, db_session)
return False

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -12,23 +13,32 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -49,24 +59,15 @@ from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@@ -81,7 +82,7 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -166,7 +167,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -184,34 +185,30 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
if tasks_generated is None:
continue
if result[1] == 0:
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += result[0]
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -226,15 +223,15 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(tenant_id, document_set_id)
rds = RedisDocumentSet(document_set_id)
# don't generate document set sync tasks if tasks are still pending
if rds.fenced:
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
@@ -254,11 +251,12 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -267,11 +265,11 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set={document_set.id} tasks_generated={tasks_generated}"
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
r.set(rds.fence_key, tasks_generated)
return tasks_generated
@@ -280,14 +278,15 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(tenant_id, usergroup_id)
if rug.fenced:
# don't generate sync tasks if tasks are still pending
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
@@ -309,11 +308,12 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -322,11 +322,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
r.set(rug.fence_key, tasks_generated)
return tasks_generated
@@ -352,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -362,12 +362,16 @@ def monitor_document_set_taskset(
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(tenant_id, document_set_id)
if not rds.fenced:
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
initial_count = rds.payload
if initial_count is None:
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
@@ -395,38 +399,48 @@ def monitor_document_set_taskset(
f"Successfully synced document set: document_set={document_set_id}"
)
rds.reset()
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
return
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if remaining > 0:
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
@@ -442,22 +456,11 @@ def monitor_connector_deletion_taskset(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# gating off pruning and indexing work before deletion starts
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
)
# clean up the rest of the related Postgres entities
@@ -521,14 +524,15 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
@@ -537,76 +541,46 @@ def monitor_ccpair_pruning_taskset(
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
initial = redis_connector.prune.generator_complete
if initial is None:
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
remaining = redis_connector.prune.get_remaining()
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if remaining > 0:
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
@@ -621,94 +595,103 @@ def monitor_ccpair_indexing_taskset(
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus(status_int)
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
@@ -717,11 +700,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -733,7 +716,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
@@ -746,33 +729,49 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"pruning={n_pruning}"
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
@@ -783,25 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
@@ -819,7 +812,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
@shared_task(
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
@@ -873,9 +866,7 @@ def vespa_metadata_sync_task(
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()

View File

@@ -1,8 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = celery_app
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)

View File

@@ -1,10 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
app = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -29,26 +29,18 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
"""Ensure the parent proc's database connections are not touched
in the new connection pool
Based on SQLAlchemy's recommendations to handle multiprocessing:
Based on the recommended approach in the SQLAlchemy docs found:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)

View File

@@ -1,5 +1,7 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -19,7 +21,6 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -30,10 +31,10 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -41,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -88,15 +102,11 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -108,13 +118,7 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
@@ -128,7 +132,13 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
@@ -141,7 +151,6 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -213,7 +222,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -274,7 +283,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -307,16 +316,26 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt,
db_session,
reason=str(e),
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -328,37 +347,6 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
@@ -378,7 +366,7 @@ def _run_indexing(
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt.id,
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
@@ -425,7 +413,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:
@@ -433,19 +421,17 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -453,8 +439,10 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -2,79 +2,20 @@ import re
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.answering.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)
return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
match_highlights=inference_section.center_chunk.match_highlights,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
@@ -295,71 +196,3 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
@@ -15,6 +14,7 @@ from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
@@ -81,7 +81,6 @@ def load_personas_from_yaml(
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)

View File

@@ -4,15 +4,13 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -27,7 +25,6 @@ class LlmDoc(BaseModel):
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
# First chunk of info for streaming QA
@@ -120,6 +117,20 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
content: str
document_id: str
@@ -135,23 +146,17 @@ class DanswerAnswer(BaseModel):
answer: str | None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
contexts: list[DanswerContexts] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
@@ -160,44 +165,12 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo

Some files were not shown because too many files have changed in this diff Show More