Compare commits

..

50 Commits

Author SHA1 Message Date
pablodanswer
4eb53ce56f rebase needs fixing 2024-08-19 07:40:53 -07:00
pablodanswer
2fc84ed63e post rebase fix 2024-08-18 16:41:12 -07:00
pablodanswer
722d5e6e54 add sequential tool calls 2024-08-18 16:40:07 -07:00
pablodanswer
14c30d2e4d add env variable 2024-08-18 15:05:44 -07:00
pablodanswer
6abad2fdd3 robust chat session state persistence 2024-08-18 15:05:44 -07:00
pablodanswer
4691e736f6 functional new message carry-over 2024-08-18 15:05:44 -07:00
pablodanswer
5a826a527f properly reset blank screen 2024-08-18 15:05:44 -07:00
pablodanswer
f92d31df70 refactored for stop / regenerate 2024-08-18 15:05:44 -07:00
pablodanswer
1eb786897a proper margin 2024-08-18 15:05:26 -07:00
pablodanswer
72471f9e1d remove parameter 2024-08-18 15:05:26 -07:00
pablodanswer
49c335d06a squash 2024-08-18 15:05:26 -07:00
pablodanswer
fda06b7739 more robust implementation for first messages 2024-08-18 15:05:26 -07:00
pablodanswer
00d44e31b3 validated + cleaner UI 2024-08-18 15:05:26 -07:00
pablodanswer
2a42c1dd18 functional once again post rebase but quite ugly 2024-08-18 15:05:26 -07:00
pablodanswer
05cd25043e add regenerate 2024-08-18 15:05:26 -07:00
pablodanswer
abebff50bb Enable seeding of analytics via file path (#2146)
* enable seeding of analytics via file path

* remove log
2024-08-18 15:05:26 -07:00
pablodanswer
0a7e672832 add handling for poorly formatting model names (#2143) 2024-08-18 15:05:26 -07:00
pablodanswer
221ab9134c add critical error just in case 2024-08-18 15:03:04 -07:00
pablodanswer
f7134202b6 slightly more specific logs 2024-08-18 14:44:10 -07:00
pablodanswer
bea11dc3aa include logs 2024-08-18 14:33:45 -07:00
pablodanswer
374b798071 update typing 2024-08-17 13:51:52 -07:00
pablodanswer
6a2e3edfcd add synchronous wrapper to avoid hampering main event loop 2024-08-17 13:39:22 -07:00
pablodanswer
2ef1731e32 tiny formatting (remove newline) 2024-08-17 09:29:39 -07:00
pablodanswer
7d4d7a5f5d clean final message handling 2024-08-17 01:14:31 -07:00
pablodanswer
ea2f9cf625 cleaner messages 2024-08-15 17:17:03 -07:00
pablodanswer
97dc9c5e31 add back stack trace detail 2024-08-15 16:46:32 -07:00
pablodanswer
249bcd46d9 clearer 2024-08-15 16:10:56 -07:00
pablodanswer
f29b727bc7 remove comments 2024-08-15 16:10:56 -07:00
pablodanswer
31fb6c0753 improve clarity + new SSE handling utility function 2024-08-15 16:10:56 -07:00
pablodanswer
a45e72c298 update utility + copy 2024-08-15 16:10:56 -07:00
pablodanswer
157548817c slightly more robust chat state 2024-08-15 16:10:56 -07:00
pablodanswer
d9396f77d1 remove false comment 2024-08-15 16:10:56 -07:00
pablodanswer
7bae6bbf8f remove log 2024-08-15 16:10:56 -07:00
pablodanswer
1d535769ed robustify 2024-08-15 16:10:56 -07:00
pablodanswer
8584a81fe2 unnecessary list removed 2024-08-15 16:10:56 -07:00
pablodanswer
5f4ac19928 robustify typing 2024-08-15 16:10:56 -07:00
pablodanswer
d898e4f738 remove logs 2024-08-15 16:10:56 -07:00
pablodanswer
19412f0aa0 add ChatState for more robust handling 2024-08-15 16:10:56 -07:00
pablodanswer
c338de30fd add new loading state to prevent collisions 2024-08-15 16:10:56 -07:00
pablodanswer
edfde621b9 formatting 2024-08-15 16:10:56 -07:00
pablodanswer
9306abf911 migrate to streaming response 2024-08-15 16:10:56 -07:00
pablodanswer
70d885b621 cleaner loop + data persistence 2024-08-15 16:10:56 -07:00
pablodanswer
53bea4f859 robustify frontend handling 2024-08-15 16:10:55 -07:00
pablodanswer
a79d734d96 typing 2024-08-15 16:10:28 -07:00
pablodanswer
25cd7de147 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
ab2916c807 robustify switching 2024-08-15 16:10:28 -07:00
pablodanswer
96112f1f95 functional rework of temporary user/assistant ID 2024-08-15 16:10:28 -07:00
pablodanswer
54502b32d3 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
9431e6c06c remove commits 2024-08-15 16:10:28 -07:00
pablodanswer
f18571d580 functional types + sidebar 2024-08-15 16:10:28 -07:00
814 changed files with 16616 additions and 70975 deletions

View File

@@ -1,76 +0,0 @@
name: 'Build and Push Docker Image with Retry'
description: 'Attempts to build and push a Docker image, with a retry on failure'
inputs:
context:
description: 'Build context'
required: true
file:
description: 'Dockerfile location'
required: true
platforms:
description: 'Target platforms'
required: true
pull:
description: 'Always attempt to pull a newer version of the image'
required: false
default: 'true'
push:
description: 'Push the image to registry'
required: false
default: 'true'
load:
description: 'Load the image into Docker daemon'
required: false
default: 'true'
tags:
description: 'Image tags'
required: true
cache-from:
description: 'Cache sources'
required: false
cache-to:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before retry in seconds'
required: false
default: '5'
runs:
using: "composite"
steps:
- name: Build and push Docker image (First Attempt)
id: buildx1
uses: docker/build-push-action@v5
continue-on-error: true
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait to retry
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Retry Attempt)
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v5
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}

View File

@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Backend Image Docker Build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: false
tags: |
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=v0.0.1

View File

@@ -7,17 +7,16 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -28,11 +27,6 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
@@ -42,20 +36,12 @@ jobs:
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -5,18 +5,14 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -35,21 +31,13 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -7,8 +7,7 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -36,7 +35,7 @@ jobs:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -113,16 +112,8 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: false
build-args: |
DANSWER_VERSION=v0.0.1
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,6 +1,3 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".
name: Tag Latest Version
on:
@@ -12,9 +9,7 @@ on:
jobs:
tag:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1

View File

@@ -1,178 +0,0 @@
name: Run Integration Tests v2
concurrency:
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run integration tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v3
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -1,68 +0,0 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -1,16 +1,12 @@
name: Python Checks
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
branches: [ main ]
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- name: Checkout code
@@ -27,9 +23,9 @@ jobs:
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@@ -1,61 +0,0 @@
name: Connector Tests
on:
pull_request:
branches: [main]
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -1,58 +0,0 @@
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -1,21 +1,16 @@
name: Python Unit Tests
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
branches: [ main ]
jobs:
backend-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
env:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -32,8 +27,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -1,23 +1,21 @@
name: Quality Checks PR
concurrency:
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Quality-Checks-PR-${{ github.head_ref }}
cancel-in-progress: true
on:
merge_group:
pull_request: null
jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.0
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- uses: pre-commit/action@v3.0.0
with:
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}

View File

@@ -1,54 +0,0 @@
name: Nightly Tag Push
on:
schedule:
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Check for existing nightly tag
id: check_tag
run: |
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
echo "A tag starting with 'nightly-latest' already exists on HEAD."
echo "tag_exists=true" >> $GITHUB_OUTPUT
else
echo "No tag starting with 'nightly-latest' exists on HEAD."
echo "tag_exists=false" >> $GITHUB_OUTPUT
fi
# don't tag again if HEAD already has a nightly-latest tag on it
- name: Create Nightly Tag
if: steps.check_tag.outputs.tag_exists == 'false'
env:
DATE: ${{ github.run_id }}
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
echo "Creating tag: $TAG_NAME"
git tag $TAG_NAME
- name: Push Tag
if: steps.check_tag.outputs.tag_exists == 'false'
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME

2
.gitignore vendored
View File

@@ -4,6 +4,6 @@
.mypy_cache
.idea
/deployment/data/nginx/app.conf
.vscode/
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml

View File

@@ -1 +0,0 @@
backend/tests/integration/tests/pruning/website

View File

@@ -1,5 +1,5 @@
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
# This will help with development iteration speed and reduce repeat tasks for dev
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=False
DISABLE_LLM_DOC_RELEVANCE=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o
# Python stuff
PYTHONPATH=../backend
PYTHONPATH=./backend
PYTHONUNBUFFERED=1
@@ -49,3 +49,4 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@@ -1,23 +1,15 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
/*
Copy this file into '.vscode/launch.json' or merge its
contents into your existing configurations.
*/
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
],
"configurations": [
{
"name": "Web Server",
@@ -25,7 +17,7 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -33,12 +25,11 @@
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -48,16 +39,16 @@
"--reload",
"--port",
"9000"
]
],
"consoleTitle": "Model Server"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -68,32 +59,32 @@
"--reload",
"--port",
"8080"
]
],
"consoleTitle": "API Server"
},
{
"name": "Indexing",
"consoleName": "Indexing",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
"consoleTitle": "Indexing"
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -102,18 +93,18 @@
},
"args": [
"--no-indexing"
]
],
"consoleTitle": "Background Jobs"
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -122,12 +113,11 @@
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -138,16 +128,18 @@
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
},
}
],
"compounds": [
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true
"name": "Run Danswer",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
]
}
]
}

View File

@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
@@ -48,26 +48,23 @@ We would love to see you there!
## Get Started 🚀
Danswer being a fully functional app, relies on some external software, specifically:
Danswer being a fully functional app, relies on some external pieces of software, specifically:
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
development purposes but also feel free to just use the containers and update with local changes by providing the
`--build` flag.
### Local Set Up
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
It is recommended to use Python version 3.11
If using a lower version, modifications will have to be made to the code.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
If using a higher version, the version of Tensorflow we use may not be available for your platform.
#### Backend: Python requirements
#### Installing Requirements
Currently, we use pip and recommend creating a virtual environment.
For convenience here's a command for it:
@@ -76,9 +73,8 @@ python -m venv .venv
source .venv/bin/activate
```
> **Note:**
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
directory
_For Windows, activate the virtual environment using Command Prompt:_
```bash
@@ -93,38 +89,34 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/ee.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
```bash
playwright install
```
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `danswer/web` run:
```bash
npm i
```
#### Docker containers for external software
You will need Docker installed to run these containers.
Install Playwright (required by the Web Connector)
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
This will update the path to include playwright
Then install Playwright by running:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
playwright install
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
#### Running Danswer locally
#### Dependent Docker Containers
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
```
(index refers to Vespa and relational_db refers to Postgres)
#### Running Danswer
To start the frontend, navigate to `danswer/web` and run:
```bash
npm run dev
@@ -135,10 +127,11 @@ Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```
The first time running Danswer, you will need to run the DB migrations for Postgres.
@@ -161,7 +154,6 @@ To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
@@ -170,58 +162,20 @@ powershell -Command "
"
```
> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
#### Wrapping up
You should now have 4 servers running:
- Web server
- Backend API
- Model server
- Background jobs
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
You've successfully set up a local Danswer instance! 🏁
#### Running the Danswer application in a container
You can run the full Danswer application stack from pre-built images including all external software dependencies.
Navigate to `danswer/deployment/docker_compose` and run:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
```
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
```
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `danswer/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we want to keep it that way!
Danswer is fully type-annotated, and we would like to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
@@ -232,7 +186,6 @@ Please double check that prettier passes before creating a pull request.
### Release Process
Danswer loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
Danswer follows the semver versioning standard.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).

View File

@@ -1,31 +0,0 @@
## Some additional notes for Mac Users
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
### Setting up Python
Ensure [Homebrew](https://brew.sh/) is already set up.
Then install python 3.11.
```bash
brew install python@3.11
```
Add python 3.11 to your path: add the following line to ~/.zshrc
```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```
> **Note:**
> You will need to open a new terminal for the path change above to take effect.
### Setting up Docker
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
ensure it is running before continuing with the docker commands.
### Formatting and Linting
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

View File

@@ -9,8 +9,7 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -41,8 +40,6 @@ RUN apt-get update && \
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
@@ -78,8 +75,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Set up application files
WORKDIR /app

View File

@@ -8,17 +8,11 @@ visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
@@ -28,18 +22,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
AutoTokenizer.from_pretrained('distilbert-base-uncased', cache_folder='/root/.cache/temp_huggingface/hub/'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Danswer, don't overwrite it with the built in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
WORKDIR /app

View File

@@ -8,53 +8,26 @@ from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
# Alembic Config object
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Add your model's MetaData object here
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
def get_schema_options() -> tuple[str, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
@@ -64,20 +37,17 @@ def run_migrations_offline() -> None:
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
schema, _ = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
@@ -85,28 +55,18 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""Run migrations in 'online' mode."""
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -120,6 +80,7 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())

View File

@@ -1,27 +0,0 @@
"""add ccpair deletion failure message
Revision ID: 0ebb1d516877
Revises: 52a219fb5233
Create Date: 2024-09-10 15:03:48.233926
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0ebb1d516877"
down_revision = "52a219fb5233"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("deletion_failure_message", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "deletion_failure_message")

View File

@@ -1,102 +0,0 @@
"""add_user_delete_cascades
Revision ID: 1b8206b29c5d
Revises: 35e6853a51d5
Create Date: 2024-09-18 11:48:59.418726
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1b8206b29c5d"
down_revision = "35e6853a51d5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey",
"credential",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey",
"chat_session",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey",
"chat_folder",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key(
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
)
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey",
"notification",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey",
"inputprompt",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
)

View File

@@ -1,135 +0,0 @@
"""embedding model -> search settings
Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
# Rename the table
op.rename_table("embedding_model", "search_settings")
# Add new columns
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
),
)
op.add_column(
"search_settings",
sa.Column(
"multilingual_expansion",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
)
op.add_column(
"search_settings",
sa.Column(
"disable_rerank_for_streaming",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
)
op.add_column(
"search_settings",
sa.Column(
"num_rerank",
sa.Integer(),
nullable=False,
server_default=str(NUM_POSTPROCESSED_RESULTS),
),
)
# Add the new column as nullable initially
op.add_column(
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
)
# Populate the new column with data from the existing embedding_model_id
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
# Create the foreign key constraint
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)
# Make the new column non-nullable
op.alter_column("index_attempt", "search_settings_id", nullable=False)
# Drop the old embedding_model_id column
op.drop_column("index_attempt", "embedding_model_id")
def downgrade() -> None:
# Add back the embedding_model_id column
op.add_column(
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
)
# Populate the old column with data from search_settings_id
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
# Make the old column non-nullable
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
# Drop the foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Drop the new search_settings_id column
op.drop_column("index_attempt", "search_settings_id")
# Rename the table back
op.rename_table("search_settings", "embedding_model")
# Remove added columns
op.drop_column("embedding_model", "num_rerank")
op.drop_column("embedding_model", "rerank_api_key")
op.drop_column("embedding_model", "rerank_provider_type")
op.drop_column("embedding_model", "rerank_model_name")
op.drop_column("embedding_model", "disable_rerank_for_streaming")
op.drop_column("embedding_model", "multilingual_expansion")
op.drop_column("embedding_model", "multipass_indexing")
op.create_foreign_key(
"index_attempt__embedding_model_fk",
"index_attempt",
"embedding_model",
["embedding_model_id"],
["id"],
)

View File

@@ -1,32 +0,0 @@
"""Add Above Below to Persona
Revision ID: 2d2304e27d8c
Revises: 4b08d97e175a
Create Date: 2024-08-21 19:15:15.762948
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2d2304e27d8c"
down_revision = "4b08d97e175a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
op.execute(
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
)
op.alter_column("persona", "chunks_above", nullable=False)
op.alter_column("persona", "chunks_below", nullable=False)
def downgrade() -> None:
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "chunks_above")

View File

@@ -1,90 +0,0 @@
"""Add curator fields
Revision ID: 351faebd379d
Revises: ee3f4b47fad5
Create Date: 2024-08-15 22:37:08.397052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add is_curator column to User__UserGroup table
op.add_column(
"user__user_group",
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
)
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
existing_nullable=False,
)
# Create the association table
op.create_table(
"credential__user_group",
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["credential_id"],
["credential.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
)
op.add_column(
"credential",
sa.Column(
"curator_public", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
# Update existing records to ensure they fit within the BASIC/ADMIN roles
op.execute(
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
)
# Remove is_curator column from User__UserGroup table
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
),
existing_type=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_nullable=False,
)
# Drop the association table
op.drop_table("credential__user_group")
op.drop_column("credential", "curator_public")

View File

@@ -1,64 +0,0 @@
"""server default chosen assistants
Revision ID: 35e6853a51d5
Revises: c99d76fcd298
Create Date: 2024-09-13 13:20:32.885317
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35e6853a51d5"
down_revision = "c99d76fcd298"
branch_labels = None
depends_on = None
DEFAULT_ASSISTANTS = [-2, -1, 0]
def upgrade() -> None:
# Step 1: Update any NULL values to the default value
# This upgrades existing users without ordered assistant
# to have default assistants set to visible assistants which are
# accessible by them.
op.execute(
"""
UPDATE "user" u
SET chosen_assistants = (
SELECT jsonb_agg(
p.id ORDER BY
COALESCE(p.display_priority, 2147483647) ASC,
p.id ASC
)
FROM persona p
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
WHERE p.is_visible = true
AND (p.is_public = true OR pu.user_id IS NOT NULL)
)
WHERE chosen_assistants IS NULL
OR chosen_assistants = 'null'
OR jsonb_typeof(chosen_assistants) = 'null'
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
"""
)
# Step 2: Alter the column to make it non-nullable
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
)
def downgrade() -> None:
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
server_default=None,
)

View File

@@ -1,46 +0,0 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)

View File

@@ -1,34 +0,0 @@
"""change default prune_freq
Revision ID: 4b08d97e175a
Revises: d9ec13955951
Create Date: 2024-08-20 15:28:52.993827
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "4b08d97e175a"
down_revision = "d9ec13955951"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 2592000
WHERE prune_freq = 86400
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 86400
WHERE prune_freq = 2592000
"""
)

View File

@@ -1,66 +0,0 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func
# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last modified represents the last time anything needing syncing to vespa changed
# including row metadata and the document itself. This obviously does not include
# the last_synced column.
op.add_column(
"document",
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
server_default=func.now(),
),
)
# last synced represents the last time this document was synced to Vespa
op.add_column(
"document",
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
)
# Set last_synced to the same value as last_modified for existing rows
op.execute(
"""
UPDATE document
SET last_synced = last_modified
"""
)
op.create_index(
op.f("ix_document_last_modified"),
"document",
["last_modified"],
unique=False,
)
op.create_index(
op.f("ix_document_last_synced"),
"document",
["last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
op.drop_column("document", "last_synced")
op.drop_column("document", "last_modified")

View File

@@ -1,79 +0,0 @@
"""assistant_rework
Revision ID: 55546a7967ee
Revises: 61ff3651add4
Create Date: 2024-09-18 17:00:23.755399
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55546a7967ee"
down_revision = "61ff3651add4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Reworking persona and user tables for new assistant features
# keep track of user's chosen assistants separate from their `ordering`
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET builtin_persona = default_persona")
op.alter_column("persona", "builtin_persona", nullable=False)
op.drop_index("_default_persona_name_idx", table_name="persona")
op.create_index(
"_builtin_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("builtin_persona = true"),
)
op.add_column(
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
)
op.add_column(
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
)
op.execute(
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
)
op.alter_column(
"user",
"visible_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.alter_column(
"user",
"hidden_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.drop_column("persona", "default_persona")
op.add_column(
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
)
def downgrade() -> None:
# Reverting changes made in upgrade
op.drop_column("user", "hidden_assistants")
op.drop_column("user", "visible_assistants")
op.drop_index("_builtin_persona_name_idx", table_name="persona")
op.drop_column("persona", "is_default_persona")
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET default_persona = builtin_persona")
op.alter_column("persona", "default_persona", nullable=False)
op.drop_column("persona", "builtin_persona")
op.create_index(
"_default_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("default_persona = true"),
)

View File

@@ -1,35 +0,0 @@
"""match_any_keywords flag for standard answers
Revision ID: 5c7fdadae813
Revises: efb35676026c
Create Date: 2024-09-13 18:52:59.256478
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c7fdadae813"
down_revision = "efb35676026c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_any_keywords",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_any_keywords")
# ### end Alembic commands ###

View File

@@ -1,162 +0,0 @@
"""Add Permission Syncing
Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Admin user who set up connectors will lose access to the docs temporarily
# only way currently to give back access is to rerun from beginning
op.add_column(
"connector_credential_pair",
sa.Column(
"access_type",
sa.String(),
nullable=True,
),
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
)
op.alter_column("connector_credential_pair", "access_type", nullable=False)
op.add_column(
"connector_credential_pair",
sa.Column(
"auto_sync_options",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
)
op.drop_column("connector_credential_pair", "is_public")
op.add_column(
"document",
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
)
op.add_column(
"document",
sa.Column(
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
),
)
op.add_column(
"document",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
op.create_table(
"user__external_user_group_id",
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("user_id"),
)
op.drop_column("external_permission", "user_id")
op.drop_column("email_to_external_user_cache", "user_id")
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
)
op.execute(
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
)
op.alter_column("connector_credential_pair", "is_public", nullable=False)
op.drop_column("connector_credential_pair", "auto_sync_options")
op.drop_column("connector_credential_pair", "access_type")
op.drop_column("connector_credential_pair", "last_time_perm_sync")
op.drop_column("document", "external_user_emails")
op.drop_column("document", "external_user_group_ids")
op.drop_column("document", "is_public")
op.drop_table("user__external_user_group_id")
# Drop the enum type at the end of the downgrade
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -9,7 +9,7 @@ import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.key_value_store.factory import get_kv_store
from danswer.dynamic_configs.factory import get_dynamic_config_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,7 +54,9 @@ def upgrade() -> None:
)
try:
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
@@ -69,7 +71,7 @@ def upgrade() -> None:
)
# Delete the dynamic config
get_kv_store().delete("token_budget_settings")
get_dynamic_config_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,27 +0,0 @@
"""persona_start_date
Revision ID: 797089dfb4d2
Revises: 55546a7967ee
Create Date: 2024-09-11 14:51:49.785835
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "797089dfb4d2"
down_revision = "55546a7967ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "search_start_date")

View File

@@ -35,22 +35,18 @@ def upgrade() -> None:
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id = (
SELECT id FROM connector_credential_pair ccp
WHERE
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
LIMIT 1
)
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
"""
)
# For good measure
op.execute(
"""
DELETE FROM index_attempt
WHERE connector_credential_pair_id IS NULL
SET connector_credential_pair_id =
CASE
WHEN ia.credential_id IS NULL THEN
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
LIMIT 1)
ELSE
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
AND credential_id = ia.credential_id)
END
WHERE ia.connector_id IS NOT NULL
"""
)

View File

@@ -1,158 +0,0 @@
"""migration confluence to be explicit
Revision ID: a3795dce87be
Revises: 1f60f60c3401
Create Date: 2024-09-01 13:52:12.006740
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import table, column
revision = "a3795dce87be"
down_revision = "1f60f60c3401"
branch_labels: None = None
depends_on: None = None
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
from urllib.parse import urlparse
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(
wiki_url: str,
) -> tuple[str, str, str]:
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
return wiki_base, space, page_id, is_confluence_cloud
def reconstruct_confluence_url(
wiki_base: str, space: str, page_id: str, is_cloud: bool
) -> str:
if is_cloud:
url = f"{wiki_base}/spaces/{space}"
if page_id:
url += f"/pages/{page_id}"
else:
url = f"{wiki_base}/display/{space}"
if page_id:
url += f"/pages/{page_id}"
return url
def upgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
# Fetch all Confluence connectors
connection = op.get_bind()
confluence_connectors = connection.execute(
sa.select(connector).where(
sa.and_(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
).fetchall()
for row in confluence_connectors:
config = row.connector_specific_config
wiki_page_url = config["wiki_page_url"]
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
new_config = {
"wiki_base": wiki_base,
"space": space,
"page_id": page_id,
"is_cloud": is_cloud,
}
for key, value in config.items():
if key not in ["wiki_page_url"]:
new_config[key] = value
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)
def downgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
confluence_connectors = (
op.get_bind()
.execute(
sa.select(connector).where(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
.fetchall()
)
for row in confluence_connectors:
config = row.connector_specific_config
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
wiki_page_url = reconstruct_confluence_url(
config["wiki_base"],
config["space"],
config.get("page_id", ""),
config["is_cloud"],
)
new_config = {"wiki_page_url": wiki_page_url}
new_config.update(
{
k: v
for k, v in config.items()
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
}
)
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)

View File

@@ -1,27 +0,0 @@
"""add last_pruned to the connector_credential_pair table
Revision ID: ac5eaac849f9
Revises: 52a219fb5233
Create Date: 2024-09-10 15:04:26.437118
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last pruned represents the last time the connector was pruned
op.add_column(
"connector_credential_pair",
sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_pruned")

View File

@@ -1,26 +0,0 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")

View File

@@ -1,26 +0,0 @@
"""Add base_url to CloudEmbeddingProvider
Revision ID: bceb1e139447
Revises: a3795dce87be
Create Date: 2024-08-28 17:00:52.554580
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "a3795dce87be"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "api_url")

View File

@@ -1,43 +0,0 @@
"""non nullable default persona
Revision ID: bd2921608c3a
Revises: 797089dfb4d2
Create Date: 2024-09-20 10:28:37.992042
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd2921608c3a"
down_revision = "797089dfb4d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set existing NULL values to False
op.execute(
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
)
# Alter the column to be not nullable with a default value of False
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
)
def downgrade() -> None:
# Revert the changes
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=True,
server_default=None,
)

View File

@@ -1,57 +0,0 @@
"""Add index_attempt_errors table
Revision ID: c5b692fa265c
Revises: 4a951134c801
Create Date: 2024-08-08 14:06:39.581972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c5b692fa265c"
down_revision = "4a951134c801"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column(
"doc_summaries",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# ### end Alembic commands ###

View File

@@ -1,31 +0,0 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@@ -1,31 +0,0 @@
"""Remove _alt suffix from model_name
Revision ID: d9ec13955951
Revises: da4c21c69164
Create Date: 2024-08-20 16:31:32.955686
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d9ec13955951"
down_revision = "da4c21c69164"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE embedding_model
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
WHERE model_name LIKE '%__danswer_alt_index'
"""
)
def downgrade() -> None:
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
pass

View File

@@ -1,65 +0,0 @@
"""chosen_assistants changed to jsonb
Revision ID: da4c21c69164
Revises: c5b692fa265c
Create Date: 2024-08-18 19:06:47.291491
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "da4c21c69164"
down_revision = "c5b692fa265c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.db.search_settings import (
from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
@@ -71,14 +71,14 @@ def upgrade() -> None:
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": IndexModelStatus.PRESENT,
"status": old_embedding_model.status,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model()
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[

View File

@@ -0,0 +1,59 @@
"""migrate tool calls
Revision ID: eb690a089310
Revises: ee3f4b47fad5
Create Date: 2024-08-04 17:07:47.533051
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eb690a089310"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the new column
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# Migrate existing data
op.execute(
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
)
# Drop the old relationship
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
def downgrade() -> None:
# Add back the old column
op.add_column(
"tool_call",
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.create_foreign_key(
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
)
# Migrate data back
op.execute(
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
)
# Drop the new column
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")

View File

@@ -1,7 +1,7 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 2d2304e27d8c
Revises: 4a951134c801
Create Date: 2024-08-12 00:11:50.915845
"""
@@ -12,17 +12,17 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels: None = None
depends_on: None = None
down_revision = "4a951134c801"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("overridden_model", sa.String(length=255), nullable=True),
sa.Column("alternate_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "overridden_model")
op.drop_column("chat_message", "alternate_model")

View File

@@ -1,32 +0,0 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 0ebb1d516877
Create Date: 2024-09-11 13:55:46.101149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "efb35676026c"
down_revision = "0ebb1d516877"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_regex")
# ### end Alembic commands ###

View File

@@ -1,172 +0,0 @@
"""embedding provider by provider type
Revision ID: f17bf3b0d9f1
Revises: 351faebd379d
Create Date: 2024-08-21 13:13:31.120460
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add provider_type column to embedding_provider
op.add_column(
"embedding_provider",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type with existing name values
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
# Make provider_type not nullable
op.alter_column("embedding_provider", "provider_type", nullable=False)
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Drop the existing primary key constraint
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create a new primary key constraint on provider_type
op.create_primary_key(
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
)
# Add provider_type column to embedding_model
op.add_column(
"embedding_model",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type for existing embedding models
op.execute(
"""
UPDATE embedding_model
SET provider_type = (
SELECT provider_type
FROM embedding_provider
WHERE embedding_provider.id = embedding_model.cloud_provider_id
)
"""
)
# Drop the old id column from embedding_provider
op.drop_column("embedding_provider", "id")
# Drop the name column from embedding_provider
op.drop_column("embedding_provider", "name")
# Drop the default_model_id column from embedding_provider
op.drop_column("embedding_provider", "default_model_id")
# Drop the old cloud_provider_id column from embedding_model
op.drop_column("embedding_model", "cloud_provider_id")
# Create the new foreign key constraint
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["provider_type"],
["provider_type"],
)
def downgrade() -> None:
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Add back the cloud_provider_id column to embedding_model
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
# Assign incrementing IDs to embedding providers
op.execute(
"""
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
)
op.execute(
"""
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
"""
)
# Update cloud_provider_id based on provider_type
op.execute(
"""
UPDATE embedding_model
SET cloud_provider_id = CASE
WHEN provider_type IS NULL THEN NULL
ELSE (
SELECT id
FROM embedding_provider
WHERE embedding_provider.provider_type = embedding_model.provider_type
)
END
"""
)
# Drop the provider_type column from embedding_model
op.drop_column("embedding_model", "provider_type")
# Add back the columns to embedding_provider
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
op.add_column(
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
)
# Drop the existing primary key constraint on provider_type
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create the original primary key constraint on id
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
# Update name with existing provider_type values
op.execute(
"""
UPDATE embedding_provider
SET name = CASE
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
WHEN provider_type = 'COHERE' THEN 'Cohere'
WHEN provider_type = 'GOOGLE' THEN 'Google'
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
ELSE provider_type
END
"""
)
# Drop the provider_type column from embedding_provider
op.drop_column("embedding_provider", "provider_type")
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)

View File

@@ -1,26 +0,0 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: bd2921608c3a
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "bd2921608c3a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("tool", "custom_headers")

View File

@@ -1,26 +0,0 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_column("user", "has_web_login")

View File

@@ -1,81 +1,26 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
info = get_access_info_for_document(
db_session=db_session,
document_id=document_id,
)
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=info[2] if info else False,
)
def get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
)
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
document_access_info = get_access_info_for_documents(
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=set(),
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.
for doc_id in document_ids:
if doc_id not in doc_access:
doc_access[doc_id] = get_null_document_access()
return doc_access
def get_access_for_documents(
document_ids: list[str],
@@ -97,7 +42,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
matches one entry in the returned set.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@@ -1,72 +1,30 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class ExternalAccess:
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Danswer
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
cls,
user_emails: list[str | None],
user_groups: list[str],
external_user_emails: list[str],
external_user_group_ids: list[str],
is_public: bool,
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@@ -1,24 +1,10 @@
from danswer.configs.constants import DocumentSource
def prefix_user_email(user_email: str) -> str:
"""Prefixes a user email to eliminate collision with group names.
This applies to both a Danswer user and an External user, this is to make the query time
more efficient"""
return f"user_email:{user_email}"
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user emails.
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def prefix_external_group(ext_group_name: str) -> str:
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -1,20 +1,20 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
store = get_dynamic_config_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store = get_dynamic_config_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -4,29 +4,29 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.key_value_store.store import KeyValueStore
from danswer.key_value_store.store import KvKeyNotFoundError
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: KeyValueStore, preferences: UserPreferences
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",

View File

@@ -5,20 +5,8 @@ from fastapi_users import schemas
class UserRole(str, Enum):
"""
User roles
- Basic can't perform any admin actions
- Admin can perform all admin actions
- Curator can perform admin actions for
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
"""
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
class UserStatus(str, Enum):
@@ -33,9 +21,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -8,18 +8,13 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
@@ -36,19 +31,15 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
@@ -68,7 +59,10 @@ from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
@@ -87,7 +81,7 @@ def verify_auth_setting() -> None:
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
@@ -112,28 +106,8 @@ def user_needs_to_be_verified() -> bool:
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
return
if not email:
raise PermissionError("Email must be specified")
email_info = validate_email(email) # can raise EmailNotValidError
for email_whitelist in whitelist:
try:
# normalized emails are now being inserted into the db
# we can remove this normalization on read after some time has passed
email_info_whitelist = validate_email(email_whitelist)
except EmailNotValidError:
continue
# oddly, normalization does not include lowercasing the user part of the
# email address ... which we want to allow
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
@@ -190,7 +164,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
) -> models.UP:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
@@ -199,27 +173,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
@@ -249,35 +203,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
# NOTE: google oauth expires after 1hr. We don't want to force the user to
# re-authenticate that frequently, so for now we'll just ignore this for
# google oauth users
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has registered.")
logger.info(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -287,45 +224,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
verify_email_domain(user.email)
logger.notice(
logger.info(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
try:
user = await self.get_by_email(credentials.username)
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})
return user
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -345,6 +256,7 @@ def get_database_strategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
@@ -427,7 +339,6 @@ async def optional_user(
async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
) -> User | None:
if optional:
return None
@@ -444,11 +355,7 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
@@ -457,40 +364,12 @@ async def double_check_user(
return user
async def current_user_with_expired_token(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user, include_expired=True)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
if DISABLE_AUTH:
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
return user
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
if DISABLE_AUTH:
return None
@@ -498,37 +377,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
detail="Access denied. User is not an admin.",
)
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -1,484 +1,343 @@
import logging
import time
from datetime import timedelta
from typing import Any
from typing import cast
import redis
from celery import bootsteps # type: ignore
from celery import Celery
from celery import current_task
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
@signals.task_postrun.connect
def celery_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# logger.debug(f"Result: {retval}")
if state not in READY_STATES:
return
if not task_id:
return
r = get_redis_client()
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(cc_pair_id)
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(cc_pair_id)
r.srem(rcp.taskset_key, task_id)
return
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
r = get_redis_client()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
time_start = time.monotonic()
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time.monotonic()
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
class CeleryTaskColoredFormatter(ColoredFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats celery's worker logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
def start(self, worker: Any) -> None:
if not celery_is_worker_primary(worker):
return
# Access the worker's event loop (hub)
hub = worker.consumer.controller.hub
# Schedule the periodic task
self.task_tref = hub.call_repeatedly(
self.interval, self.run_periodic_task, worker
)
task_logger.info("Scheduled periodic task with hub.")
def run_periodic_task(self, worker: Any) -> None:
try:
if not worker.primary_worker_lock:
return
if not hasattr(worker, "primary_worker_lock"):
return
r = get_redis_client()
lock: redis.lock.Lock = worker.primary_worker_lock
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
if self.task_tref:
self.task_tref.cancel()
task_logger.info("Canceled periodic task with hub.")
celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
_SYNC_BATCH_SIZE = 100
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
raise ValueError(
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
try:
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
return delete_connector_credential_pair(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
logger.exception(
f"Failed to run pruning for connector id {connector_id} due to {e}"
)
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine()) as db_session:
try:
cursor = None
while True:
document_batch, cursor = fetch_documents_for_document_set_paginated(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
if cursor is None:
break
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if should_sync_doc_set(document_set, db_session):
logger.info(f"Syncing the {document_set.name} document set")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
@celery_app.task(
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
logger.info(f"Deleting the {cc_pair.name} connector credential pair")
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
),
)
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-vespa-sync": {
"task": "check_for_vespa_sync_task",
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-connector-deletion-task": {
"task": "check_for_connector_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
{
"kombu-message-cleanup": {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
}
)
celery_app.conf.beat_schedule.update(
{
"monitor-vespa-sync": {
"task": "monitor_vespa_sync",
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)

View File

@@ -1,480 +0,0 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id, current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class differs from the default in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=self._id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
if redis_client.exists(self.fence_key):
return True
return False
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
used to implement task prioritization.
This operation is not atomic."""
total_length = 0
for i in range(len(DanswerCeleryPriority)):
queue_name = queue
if i > 0:
queue_name += CELERY_SEPARATOR
queue_name += str(i)
length = r.llen(queue_name)
total_length += cast(int, length)
return total_length

View File

@@ -1,11 +1,12 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -15,44 +16,36 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
def get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client()
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
deletion_task = get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
return None
@@ -63,19 +56,93 @@ def get_deletion_attempt_snapshot(
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now!")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
all docs using the load_from_state and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
@@ -98,36 +165,6 @@ def extract_ids_from_runnable_connector(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
"""Checks to see if we're listening to the named queue"""
# how to get a list of queues this worker is listening to
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
queue_names = list(worker.app.amqp.queues.consume_from.keys())
for queue_name in queue_names:
if queue_name == name:
return True
return False
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken to determine if a celery worker
is 'primary', as defined by us. But the way we do it is to check the hostname set
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("heavy"):
return False
return True

View File

@@ -1,97 +0,0 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
REDIS_SCHEME = "redis"
# SSL-specific query parameters for Redis URL
SSL_QUERY_PARAMS = ""
if REDIS_SSL:
REDIS_SCHEME = "rediss"
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
if REDIS_SSL_CA_CERTS:
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
"retry_on_timeout": True,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
# Option 1: Reduces generator task result sizes by roughly 20%
# task_compression = "bzip2"
# task_serializer = "pickle"
# result_compression = "bzip2"
# result_serializer = "pickle"
# accept_content=["pickle"]
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
# can be large. small tasks change very little
# def pickle_bz2_encoder(data):
# return bz2.compress(pickle.dumps(data))
# def pickle_bz2_decoder(data):
# return pickle.loads(bz2.decompress(data))
# from kombu import serialization # To register custom serialization with Celery/Kombu
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]

View File

@@ -1,110 +0,0 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

View File

@@ -1,137 +0,0 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
@shared_task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True

View File

@@ -1,239 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_prune_task_2",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task_2() -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, r, lock_beat
)
if not tasks_created:
continue
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def ccpair_pruning_generator_task_creation_helper(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return None
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return None
return try_creating_prune_generator_task(cc_pair, db_session, r)
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
return 1
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, redis_increment_callback
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e

View File

@@ -1,113 +0,0 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where there
this the is only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete(doc_ids=[document_id])
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(document_id, fields=fields)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -1,576 +0,0 @@
import traceback
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(
f"Stale document sync progress: remaining={count} initial={initial_count}"
)
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
)
return
try:
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
mark_ccpair_as_pruned(cc_pair_id, db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
with Session(get_sqlalchemy_engine()) as db_session:
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
@shared_task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -0,0 +1,196 @@
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where there
this the is only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_cnts = get_document_connector_cnts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
num_docs_deleted = 0
while True:
documents = get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
num_docs_deleted += len(documents)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.info(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted

View File

@@ -11,10 +11,12 @@ from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
@@ -22,7 +24,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -30,7 +31,6 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -41,16 +41,16 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> ConnectorRunner:
) -> GenerateDocumentsOutput:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
Returns an iterator of document batches and whether the returned documents
Returns an interator of document batches and whether the returned documents
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
@@ -58,32 +58,49 @@ def _get_connector_runner(
try:
runnable_connector = instantiate_connector(
db_session=db_session,
source=attempt.connector_credential_pair.connector.source,
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
attempt.connector_credential_pair.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if (
attempt.connector_credential_pair.connector_id is None
or attempt.connector_credential_pair.connector_id is None
):
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
def _run_indexing(
@@ -97,71 +114,55 @@ def _run_indexing(
"""
start_time = time.time()
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
db_embedding_model
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
ignore_time_skip=index_attempt.from_beginning
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
earliest_index_time = (
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
)
last_successful_index_time = (
earliest_index_time
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
db_connector.indexing_start.timestamp()
if index_attempt.from_beginning and db_connector.indexing_start is not None
else (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
@@ -180,7 +181,7 @@ def _run_indexing(
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
connector_runner = _get_connector_runner(
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
@@ -192,19 +193,15 @@ def _run_indexing(
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_connector)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and db_embedding_model.status != IndexModelStatus.FUTURE
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
@@ -231,13 +228,13 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
@@ -264,7 +261,7 @@ def _run_indexing(
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
logger.info(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
@@ -280,7 +277,7 @@ def _run_indexing(
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
logger.info(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
@@ -292,7 +289,7 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
@@ -318,52 +315,15 @@ def _run_indexing(
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
logger.info(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
logger.info("Memory tracer stopped.")
mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -372,6 +332,11 @@ def _run_indexing(
run_dt=run_end_dt,
)
elapsed_time = time.time() - start_time
logger.info(
f"Connector succeeded: docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
@@ -400,22 +365,17 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
) -> None:
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it

View File

@@ -48,9 +48,9 @@ class DanswerTracer:
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
logger.info(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
logger.info(f"* {line}")
@staticmethod
def log_diff(
@@ -60,9 +60,9 @@ class DanswerTracer:
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
logger.info(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
logger.info(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:

View File

@@ -14,6 +14,14 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
@@ -85,16 +93,9 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# register_task must come before fn = apply_async or else the task
# might run mark_task_start (and crash) before the task row exists
db_task = register_task(task_name, db_session)
# mark the task as started
task = fn(args, kwargs, *other_args, **other_kwargs)
# we update the celery task id for diagnostic purposes
# but it isn't currently used by any code
db_task.task_id = task.id
db_session.commit()
register_task(task.id, task_name, db_session)
return task

View File

@@ -17,13 +17,15 @@ from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SqlEngine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -31,14 +33,11 @@ from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -61,27 +60,20 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if model.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
@@ -96,20 +88,14 @@ def _should_create_new_indexing(
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
if cc_pair.status == ConnectorCredentialPairStatus.PAUSED or connector.id == 0:
return False
if not last_index:
@@ -134,6 +120,16 @@ def _should_create_new_indexing(
return time_since_index.total_seconds() >= connector.refresh_freq
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
@@ -174,42 +170,35 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.search_settings_id,
attempt.embedding_model_id,
)
)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
embedding_models = [get_current_db_embedding_model(db_session)]
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for search_settings_instance in search_settings:
for model in embedding_models:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, search_settings_instance.id) in ongoing:
if (cc_pair.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
cc_pair.id, model.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(
cc_pair.id, search_settings_instance.id, db_session
)
create_index_attempt(cc_pair.id, model.id, db_session)
def cleanup_indexing_jobs(
@@ -217,6 +206,7 @@ def cleanup_indexing_jobs(
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
@@ -225,12 +215,10 @@ def cleanup_indexing_jobs(
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if job.status == "error":
logger.error(job.exception())
@@ -305,7 +293,7 @@ def kickoff_indexing_jobs(
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.search_settings)
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
@@ -317,15 +305,10 @@ def kickoff_indexing_jobs(
indexing_attempt_count = 0
primary_client_full = False
secondary_client_full = False
for attempt, search_settings in new_indexing_attempts:
if primary_client_full and secondary_client_full:
break
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector_credential_pair.connector is None:
@@ -347,28 +330,20 @@ def kickoff_indexing_jobs(
)
continue
if not use_secondary_index:
if not primary_client_full:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
primary_client_full = True
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
else:
if not secondary_client_full:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
secondary_client_full = True
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if run:
if indexing_attempt_count == 0:
@@ -406,23 +381,18 @@ def update_loop(
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.debug("Running a first inference to warm up embedding model")
warm_up_bi_encoder(
embedding_model=embedding_model,
embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -451,7 +421,6 @@ def update_loop(
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
@@ -483,11 +452,9 @@ def update_loop(
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
# initialize the Postgres connection pool
SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
logger.notice("Starting indexing service")
logger.info("Starting indexing service")
update_loop()

View File

@@ -36,8 +36,7 @@ def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
stop_at_message_id: int | None = None,
parent_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
@@ -63,12 +62,7 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
# Break if at the end of the chain
# or have reached the `final_id` of the submitted message
if not child_msg or (
stop_at_message_id and current_message.id == stop_at_message_id
):
if not child_msg or (parent_id and current_message.id == parent_id):
break
current_message = id_to_msg.get(child_msg)

View File

@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
@@ -17,32 +18,30 @@ from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
db_session: Session, prompts_yaml: str = PROMPTS_YAML
) -> None:
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
@@ -50,117 +49,117 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -10,7 +9,6 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -36,35 +34,16 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class LLMRelevanceFilterResponse(BaseModel):
llm_selected_doc_indices: list[int]
class FinalUsedContextDocsResponse(BaseModel):
final_context_docs: list[LlmDoc]
relevant_chunk_indices: list[int]
class RelevanceAnalysis(BaseModel):
@@ -85,6 +64,10 @@ class DocumentRelevance(BaseModel):
relevance_summaries: dict[str, RelevanceAnalysis]
class Delimiter(BaseModel):
delimiter: bool
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
@@ -97,16 +80,6 @@ class CitationInfo(BaseModel):
document_id: str
class AllCitations(BaseModel):
citations: list[CitationInfo]
# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
citation_map: dict[int, int]
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
@@ -152,7 +125,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_selected_doc_indices: list[int] | None = None
llm_chunks_indices: list[int] | None = None
error_msg: str | None = None
@@ -161,7 +134,7 @@ class ImageGenerationDisplay(BaseModel):
class CustomToolResponse(BaseModel):
response: ToolResultType
response: dict
tool_name: str
@@ -173,7 +146,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
| Delimiter
)

View File

@@ -1,4 +1,3 @@
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
@@ -7,15 +6,13 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import Delimiter
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@@ -35,13 +32,13 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -73,9 +70,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
@@ -90,15 +85,13 @@ from danswer.tools.internet_search.internet_search_tool import (
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallMetadata
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -107,9 +100,9 @@ from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def _translate_citations(
def translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> MessageSpecificCitations:
) -> dict[int, int]:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
@@ -124,7 +117,7 @@ def _translate_citations(
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
return citation_to_saved_doc_id_map
def _handle_search_tool_response_summary(
@@ -246,15 +239,13 @@ ChatPacket = (
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| Delimiter
)
ChatPacketStream = Iterator[ChatPacket]
@@ -273,7 +264,6 @@ def stream_chat_message_objects(
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -281,11 +271,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
try:
user_id = user.id if user is not None else None
@@ -342,9 +327,9 @@ def stream_chat_message_objects(
Callable[[str], list[int]], llm_tokenizer.encode
)
search_settings = get_current_search_settings(db_session)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
@@ -365,7 +350,7 @@ def stream_chat_message_objects(
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
stop_at_message_id=parent_id,
parent_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
@@ -445,7 +430,6 @@ def stream_chat_message_objects(
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
)
# Generates full documents currently
@@ -477,6 +461,8 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
@@ -491,17 +477,16 @@ def stream_chat_message_objects(
reserved_assistant_message_id=reserved_message_id,
)
overridden_model = (
alternate_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
alternate_model=alternate_model,
# message=,
# rephrased_query=,
# token_count=,
@@ -609,13 +594,8 @@ def stream_chat_message_objects(
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
@@ -630,6 +610,7 @@ def stream_chat_message_objects(
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
tool_has_been_called = False # TODO remove
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -670,6 +651,8 @@ def stream_chat_message_objects(
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
tool_has_been_called = True
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
@@ -680,11 +663,9 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -707,14 +688,9 @@ def stream_chat_message_objects(
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
relevant_chunk_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -747,89 +723,137 @@ def stream_chat_message_objects(
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
logger.exception("Failed to process chat message.")
if isinstance(packet, Delimiter):
db_citations = None
error_msg = str(e)
yield StreamingError(error=error_msg)
db_session.rollback()
return
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query
if qa_docs_response
else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
yield Delimiter(delimiter=True)
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
)
else:
if isinstance(packet, ToolCallMetadata):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
stack_trace = traceback.format_exc()
logger.exception(f"Failed to process chat message: {error_msg}")
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=error_msg)
db_session.rollback()
return
# Post-LLM answer processing
try:
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)
if not tool_has_been_called:
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else []
),
)
else None,
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@log_generator_function_time()
@@ -850,4 +874,4 @@ def stream_chat_message(
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
yield get_json_line(obj.dict())

View File

@@ -42,7 +42,8 @@ prompts:
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true

View File

@@ -1,4 +1,4 @@
from typing_extensions import TypedDict # noreorder
from typing import TypedDict
from pydantic import BaseModel

View File

@@ -93,14 +93,6 @@ SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Danswer will listen to the `expires_at` returned by the identity
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
# after this time has elapsed. Disabled since by default many auth providers
# have very short expiry times (e.g. 1 hour) which provide a poor user experience
TRACK_EXTERNAL_IDP_EXPIRY = (
os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true"
)
#####
# DB Configs
@@ -126,7 +118,6 @@ try:
except ValueError:
INDEX_BATCH_SIZE = 16
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@@ -138,12 +129,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
)
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -156,43 +141,6 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# will propagate to both our redis client as well as celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting to None may help when there is a proxy in the way closing idle connections
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
CELERY_BROKER_POOL_LIMIT = int(
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
)
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
#####
# Connector Configs
#####
@@ -244,8 +192,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
]
# Avoid to get archived pages
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
@@ -256,12 +204,7 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 50 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
@@ -269,10 +212,6 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -296,7 +235,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
@@ -356,14 +295,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
#####
# File based Key Value store no longer used
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
@@ -384,10 +321,6 @@ LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
LOG_POSTGRES_CONN_COUNTS = (
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
@@ -412,11 +345,3 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")

View File

@@ -31,9 +31,8 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
@@ -45,7 +44,7 @@ DISABLE_LLM_QUERY_REPHRASE = (
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
@@ -54,7 +53,7 @@ HYBRID_ALPHA_KEYWORD = max(
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is most of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephase the query
@@ -83,15 +82,8 @@ DISABLE_LLM_DOC_RELEVANCE = (
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# Set this to "true" to hard delete chats
# This will make chats unviewable by admins after a user deletes them
# As opposed to soft deleting them, which just hides them from non-admin users
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

View File

@@ -1,6 +1,3 @@
import platform
import socket
from enum import auto
from enum import Enum
SOURCE_TYPE = "source_type"
@@ -15,6 +12,10 @@ ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
# For tool calling
MAXIMUM_TOOL_CALL_SEQUENCE = 5
# For chunking/processing chunks
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
@@ -36,12 +37,9 @@ POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -51,7 +49,6 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
@@ -63,13 +60,9 @@ KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -106,12 +99,10 @@ class DocumentSource(str, Enum):
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
ASANA = "asana"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
@@ -142,12 +133,6 @@ class AuthType(str, Enum):
SAML = "saml"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
SLACK = "Slack"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
@@ -180,39 +165,3 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class DanswerCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()
MEDIUM = auto()
LOW = auto()
LOWEST = auto()
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore

View File

@@ -73,15 +73,3 @@ DANSWER_BOT_FEEDBACK_REMINDER = int(
DANSWER_BOT_REPHRASE_MESSAGE = (
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
)
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
# responses DanswerBot can send in a given time period.
# Set to 0 to disable the limit.
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
)
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
# of seconds until the response limit is reset.
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
)

View File

@@ -39,13 +39,9 @@ SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
# User's set embedding batch size overrides the default encoding batch sizes
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
@@ -55,23 +51,37 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####
# NOTE: the 3 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument
# Additionally Danswer supports GPT4All and custom request library based models
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
# This is the minimum token context we will leave for the LLM to generate an answer
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily

View File

@@ -59,8 +59,6 @@ if __name__ == "__main__":
latest_docs = test_connector.poll_source(one_day_ago, current)
```
> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main.
### Additional Required Changes:
#### Backend Changes
@@ -70,16 +68,17 @@ if __name__ == "__main__":
[here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33)
#### Frontend Changes
- Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/sources.ts#L59).
- Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/connectors/connectors.ts#L79).
- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/`
- Create the new icon, type, source, and filter changes
(refer to existing [PR](https://github.com/danswer-ai/danswer/pull/139))
#### Docs Changes
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs.
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs
### Before opening PR
1. Be sure to fully test changes end to end with setting up the connector and updating the index with new docs from the
new connector. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page).
2. Add a folder + tests under `backend/tests/daily/connectors` director. For an example, checkout the [test for Confluence](https://github.com/danswer-ai/danswer/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to setup the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass.
3. Be sure to run the linting/formatting, refer to the formatting and linting section in
new connector.
2. Be sure to run the linting/formatting, refer to the formatting and linting section in
[CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md#formatting-and-linting)

View File

@@ -1,233 +0,0 @@
import time
from collections.abc import Iterator
from datetime import datetime
from typing import Dict
import asana # type: ignore
from danswer.utils.logger import setup_logger
logger = setup_logger()
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
def __init__(
self,
id: str,
title: str,
text: str,
link: str,
last_modified: datetime,
project_gid: str,
project_name: str,
) -> None:
self.id = id
self.title = title
self.text = text
self.link = link
self.last_modified = last_modified
self.project_gid = project_gid
self.project_name = project_name
def __str__(self) -> str:
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None # type: ignore
self.workspace_gid = workspace_gid
self.team_gid = team_gid
self.configuration = asana.Configuration()
self.api_client = asana.ApiClient(self.configuration)
self.tasks_api = asana.TasksApi(self.api_client)
self.stories_api = asana.StoriesApi(self.api_client)
self.users_api = asana.UsersApi(self.api_client)
self.project_api = asana.ProjectsApi(self.api_client)
self.workspaces_api = asana.WorkspacesApi(self.api_client)
self.api_error_count = 0
self.configuration.access_token = api_token
self.task_count = 0
def get_tasks(
self, project_gids: list[str] | None, start_date: str
) -> Iterator[AsanaTask]:
"""Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace."""
logger.info("Starting to fetch Asana projects")
projects = self.project_api.get_projects(
opts={
"workspace": self.workspace_gid,
"opt_fields": "gid,name,archived,modified_at",
}
)
start_seconds = int(time.mktime(datetime.now().timetuple()))
projects_list = []
project_count = 0
for project_info in projects:
project_gid = project_info["gid"]
if project_gids is None or project_gid in project_gids:
projects_list.append(project_gid)
else:
logger.debug(
f"Skipping project: {project_gid} - not in accepted project_gids"
)
project_count += 1
if project_count % 100 == 0:
logger.info(f"Processed {project_count} projects")
logger.info(f"Found {len(projects_list)} projects to process")
for project_gid in projects_list:
for task in self._get_tasks_for_project(
project_gid, start_date, start_seconds
):
yield task
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
if self.api_error_count > 0:
logger.warning(
f"Encountered {self.api_error_count} API errors during task fetching"
)
def _get_tasks_for_project(
self, project_gid: str, start_date: str, start_seconds: int
) -> Iterator[AsanaTask]:
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"
)
simple_start_date = start_date.split(".")[0].split("+")[0]
logger.info(
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
)
opts = {
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
"workspace,permalink_url",
"modified_since": start_date,
}
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
for data in tasks_from_api:
self.task_count += 1
if self.task_count % 10 == 0:
end_seconds = time.mktime(datetime.now().timetuple())
runtime_seconds = end_seconds - start_seconds
if runtime_seconds > 0:
logger.info(
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
)
logger.debug(f"Processing Asana task: {data['name']}")
text = self._construct_task_text(data)
try:
text += self._fetch_and_add_comments(data["gid"])
last_modified_date = self.format_date(data["modified_at"])
text += f"Last modified: {last_modified_date}\n"
task = AsanaTask(
id=data["gid"],
title=data["name"],
text=text,
link=data["permalink_url"],
last_modified=datetime.fromisoformat(data["modified_at"]),
project_gid=project_gid,
project_name=project["name"],
)
yield task
except Exception:
logger.error(
f"Error processing task {data['gid']} in project {project_gid}",
exc_info=True,
)
self.api_error_count += 1
def _construct_task_text(self, data: Dict) -> str:
text = f"{data['name']}\n\n"
if data["notes"]:
text += f"{data['notes']}\n\n"
if data["created_by"] and data["created_by"]["gid"]:
creator = self.get_user(data["created_by"]["gid"])["name"]
created_date = self.format_date(data["created_at"])
text += f"Created by: {creator} on {created_date}\n"
if data["due_on"]:
due_date = self.format_date(data["due_on"])
text += f"Due date: {due_date}\n"
if data["completed_at"]:
completed_date = self.format_date(data["completed_at"])
text += f"Completed on: {completed_date}\n"
text += "\n"
return text
def _fetch_and_add_comments(self, task_gid: str) -> str:
text = ""
stories_opts: Dict[str, str] = {}
story_start = time.time()
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
story_count = 0
comment_count = 0
for story in stories:
story_count += 1
if story["resource_subtype"] == "comment_added":
comment = self.stories_api.get_story(
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
)
commenter = self.get_user(comment["created_by"]["gid"])["name"]
text += f"Comment by {commenter}: {comment['text']}\n\n"
comment_count += 1
story_duration = time.time() - story_start
logger.debug(
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
)
return text
def get_user(self, user_gid: str) -> Dict:
if self._user is not None:
return self._user
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
if not self._user:
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
return {"name": "Unknown"}
return self._user
def format_date(self, date_str: str) -> str:
date = datetime.fromisoformat(date_str)
return time.strftime("%Y-%m-%d", date.timetuple())
def get_time(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

View File

@@ -1,120 +0,0 @@
import datetime
from typing import Any
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana import asana_api
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AsanaConnector(LoadConnector, PollConnector):
def __init__(
self,
asana_workspace_id: str,
asana_project_ids: str | None = None,
asana_team_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.api_token = credentials["asana_api_token_secret"]
self.asana_client = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
logger.info("Asana credentials loaded and API client initialized")
return None
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
start_time = datetime.datetime.fromtimestamp(start).isoformat()
logger.info(f"Starting Asana poll from {start_time}")
asana = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
docs_batch: list[Document] = []
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
doc = self._message_to_doc(task)
docs_batch.append(doc)
if len(docs_batch) >= self.batch_size:
logger.info(f"Yielding batch of {len(docs_batch)} documents")
yield docs_batch
docs_batch = []
if docs_batch:
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
yield docs_batch
logger.info("Asana poll completed")
def load_from_state(self) -> GenerateDocumentsOutput:
logger.notice("Starting full index of all Asana tasks")
return self.poll_source(start=0, end=None)
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,
metadata={
"group": task.project_gid,
"project": task.project_name,
},
)
if __name__ == "__main__":
import time
import os
logger.notice("Starting Asana connector test")
connector = AsanaConnector(
os.environ["WORKSPACE_ID"],
os.environ["PROJECT_IDS"],
os.environ["TEAM_ID"],
)
connector.load_credentials(
{
"asana_api_token_secret": os.environ["API_TOKEN"],
}
)
logger.info("Loading all documents from Asana")
all_docs = connector.load_from_state()
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
logger.info("Polling for documents updated in the last 24 hours")
latest_docs = connector.poll_source(one_day_ago, current)
for docs in latest_docs:
for doc in docs:
print(doc.id)
logger.notice("Asana connector test completed")

View File

@@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
Raises ValueError for unsupported bucket types.
"""
logger.debug(
logger.info(
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
)
@@ -194,8 +194,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
try:
text = extract_file_text(
name,
BytesIO(downloaded_file),
file_name=name,
break_on_unprocessable=False,
)
batch.append(
@@ -220,7 +220,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
logger.debug("Loading blob objects")
logger.info("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),

View File

@@ -1,32 +0,0 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used

View File

@@ -7,25 +7,19 @@ from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import urlparse
import bs4
from atlassian import Confluence # type:ignore
from requests import HTTPError
from danswer.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.confluence_utils import get_used_attachments
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@@ -48,12 +42,77 @@ logger = setup_logger()
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR = (
"User not permitted to view attachments on content"
)
NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
"No parent or not permitted to view content with id"
)
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
wiki_base is https://danswer.atlassian.net/wiki
space is 1234abcd
page_id is 5678efgh
"""
parsed_url = urlparse(wiki_url)
wiki_base = (
parsed_url.scheme
+ "://"
+ parsed_url.netloc
+ parsed_url.path.split("/spaces")[0]
)
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
wiki_base is https://danswer.ai/confluence
space is 1234abcd
page_id is 5678efgh
"""
# /display/ is always right before the space and at the end of the base print()
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = (
parsed_url.scheme
+ "://"
+ parsed_url.netloc
+ parsed_url.path.split(DISPLAY)[0]
)
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
try:
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
wiki_url
)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
except Exception as e:
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
return wiki_base, space, page_id, is_confluence_cloud
@lru_cache()
@@ -109,6 +168,24 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filename currently in used
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -123,38 +200,19 @@ def _comment_dfs(
comments_str += "\nComment:\n" + parse_html_page(
comment_html, confluence_client
)
try:
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
start=None,
limit=None,
expand="body.storage.value",
)
comments_str = _comment_dfs(
comments_str, child_comment_pages, confluence_client
)
except HTTPError as e:
# not the cleanest, but I'm not aware of a nicer way to check the error
if NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR not in str(e):
raise
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
start=None,
limit=None,
expand="body.storage.value",
)
comments_str = _comment_dfs(
comments_str, child_comment_pages, confluence_client
)
return comments_str
def _datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
# If no timezone info, assume it is UTC
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
else:
# If not in UTC, translate it
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
class RecursiveIndexer:
def __init__(
self,
@@ -284,10 +342,7 @@ class RecursiveIndexer:
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_base: str,
space: str,
is_cloud: bool,
page_id: str = "",
wiki_page_url: str,
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -301,15 +356,15 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_recursively = index_recursively
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
self.page_id = page_id
self.is_cloud = is_cloud
(
self.wiki_base,
self.space,
self.page_id,
self.is_cloud,
) = extract_confluence_keys_from_url(wiki_page_url)
self.space_level_scan = False
self.confluence_client: Confluence | None = None
if self.page_id is None or self.page_id == "":
@@ -329,6 +384,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
cloud=self.is_cloud,
)
return None
@@ -347,7 +403,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start=start_ind,
limit=batch_size,
status=(
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
@@ -368,9 +426,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start=start_ind + i,
limit=1,
status=(
None
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
else "current"
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
@@ -477,255 +535,145 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.exception("Ran into exception when fetching labels from Confluence")
return []
@classmethod
def _attachment_to_download_link(
cls, confluence_client: Confluence, attachment: dict[str, Any]
) -> str:
return confluence_client.url + attachment["_links"]["download"]
@classmethod
def _attachment_to_content(
cls,
confluence_client: Confluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
return None
download_link = cls._attachment_to_download_link(confluence_client, attachment)
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list = []
) -> str:
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500, expand=expand
page_id, start=0, limit=500
)
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
unused_attachments.append(attachment)
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
continue
attachment_content = self._attachment_to_content(
confluence_client, attachment
)
if attachment_content:
files_attachment_content.append(attachment_content)
if attachment["title"] not in files_in_used:
continue
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
continue
download_link = confluence_client.url + attachment["_links"]["download"]
response = confluence_client._session.get(download_link)
if response.status_code == 200:
extract = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
)
files_attachment_content.append(extract)
except Exception as e:
if isinstance(
e, HTTPError
) and NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR in str(e):
logger.warning(
f"User does not have access to attachments on page '{page_id}'"
)
return "", []
if not self.continue_on_failure:
raise e
logger.exception(
f"Ran into exception when fetching attachments from Confluence: {e}"
)
return "\n".join(files_attachment_content), unused_attachments
return "\n".join(files_attachment_content)
def _get_doc_batch(
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
) -> tuple[list[Document], list[dict[str, Any]], int]:
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
unused_attachments: list[dict[str, Any]] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
last_modified_str = page["version"]["when"]
author = cast(str | None, page["version"].get("by", {}).get("email"))
last_modified = datetime.fromisoformat(last_modified_str)
if time_filter and not time_filter(last_modified):
continue
if last_modified.tzinfo is None:
# If no timezone info, assume it is UTC
last_modified = last_modified.replace(tzinfo=timezone.utc)
else:
# If not in UTC, translate it
last_modified = last_modified.astimezone(timezone.utc)
page_id = page["id"]
if time_filter is None or time_filter(last_modified):
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
# check disallowed labels
if self.labels_to_skip:
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
# check disallowed labels
if self.labels_to_skip:
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
page["body"]
.get("storage", page["body"].get("view", {}))
.get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
# The url and the id are the same
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
)
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
unused_attachments.extend(unused_page_attachments)
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=page_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
files_in_used = get_used_attachments(page_html, self.confluence_client)
attachment_text = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
)
return (
doc_batch,
unused_attachments,
len(batch),
)
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": self.space
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
def _get_attachment_batch(
self,
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
end_ind = min(start_ind + self.batch_size, len(attachments))
for attachment in attachments[start_ind:end_ind]:
last_updated = _datetime_from_string(
attachment["history"]["lastUpdated"]["when"]
)
if time_filter and not time_filter(last_updated):
continue
# The url and the id are the same
attachment_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"]
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment
)
if attachment_content is None:
continue
creator_email = attachment["history"]["createdBy"].get("email")
comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
for label in attachment["metadata"]["labels"]["results"]:
attachment_labels.append(label["name"])
doc_metadata["labels"] = attachment_labels
doc_batch.append(
Document(
id=attachment_url,
sections=[Section(link=attachment_url, text=attachment_content)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment["title"],
doc_updated_at=last_updated,
primary_owners=(
[BasicExpertInfo(email=creator_email)]
if creator_email
else None
),
metadata=doc_metadata,
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=page_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
)
)
)
return doc_batch, end_ind - start_ind
return doc_batch, len(batch)
def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind
)
unused_attachments.extend(unused_attachments_batch)
doc_batch, num_pages = self._get_doc_batch(start_ind)
start_ind += num_pages
if doc_batch:
yield doc_batch
@@ -733,23 +681,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -758,11 +692,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
doc_batch, num_pages = self._get_doc_batch(
start_ind, time_filter=lambda t: start_time <= t <= end_time
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
if doc_batch:
yield doc_batch
@@ -770,29 +702,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
if __name__ == "__main__":
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=os.environ["CONFLUENCE_TEST_SPACE"],
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
index_recursively=True,
)
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
connector.load_credentials(
{
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],

View File

@@ -23,39 +23,25 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
max_retries = 5
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(max_retries):
for attempt in range(10):
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
@@ -69,14 +55,5 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
else:
# re-raise, let caller handle
raise
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
logger.warning(f"Confluence Internal Error, retrying... {e}")
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
if attempt == max_retries - 1:
raise e
return cast(F, wrapped_call)

View File

@@ -1,70 +0,0 @@
import sys
from datetime import datetime
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.utils.logger import setup_logger
logger = setup_logger()
TimeRange = tuple[datetime, datetime]
class ConnectorRunner:
def __init__(
self,
connector: BaseConnector,
time_range: TimeRange | None = None,
fail_loudly: bool = False,
):
self.connector = connector
if isinstance(self.connector, PollConnector):
if time_range is None:
raise ValueError("time_range is required for PollConnector")
self.doc_batch_generator = self.connector.poll_source(
time_range[0].timestamp(), time_range[1].timestamp()
)
elif isinstance(self.connector, LoadConnector):
if time_range and fail_loudly:
raise ValueError(
"time_range specified, but passed in connector is not a PollConnector"
)
self.doc_batch_generator = self.connector.load_from_state()
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
def run(self) -> GenerateDocumentsOutput:
"""Adds additional exception logging to the connector."""
try:
yield from self.doc_batch_generator
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
# Traverse the traceback to find the last frame where the exception was raised
tb = exc_traceback
if tb is None:
logger.error("No traceback found for exception")
raise
while tb.tb_next:
tb = tb.tb_next # Move to the next frame in the traceback
# Get the local variables from the frame where the exception occurred
local_vars = tb.tb_frame.f_locals
local_vars_str = "\n".join(
f"{key}: {value}" for key, value in local_vars.items()
)
logger.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str}"
)
raise

View File

@@ -56,7 +56,7 @@ class _RateLimitDecorator:
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logger.notice(
logger.info(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)

View File

@@ -9,7 +9,6 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -46,15 +45,10 @@ def extract_jira_project(url: str) -> tuple[str, str]:
return jira_base, jira_project
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
def extract_text_from_content(content: dict) -> str:
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in content:
for block in content["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
@@ -78,15 +72,18 @@ def _get_comment_strs(
comment_strs = []
for comment in jira.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if hasattr(comment, "body"):
body_text = extract_text_from_content(comment.raw["body"])
elif hasattr(comment, "raw"):
body = comment.raw.get("body", "No body content available")
body_text = (
extract_text_from_content(body) if isinstance(body, dict) else body
)
else:
body_text = "No body attribute found"
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
@@ -129,24 +126,13 @@ def fetch_jira_issues_batch(
)
continue
description = (
jira.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
@@ -189,7 +175,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
sections=[Section(link=page_url, text=semantic_rep)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -245,12 +231,10 @@ class JiraConnector(LoadConnector, PollConnector):
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {quoted_project}",
jql=f"project = {self.jira_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
@@ -278,10 +262,8 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {quoted_project} AND "
f"project = {self.jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)

View File

@@ -97,8 +97,8 @@ class DropboxConnector(LoadConnector, PollConnector):
link = self._get_shared_link(entry.path_display)
try:
text = extract_file_text(
entry.name,
BytesIO(downloaded_file),
file_name=entry.name,
break_on_unprocessable=False,
)
batch.append(

View File

@@ -4,7 +4,6 @@ from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
from danswer.connectors.bookstack.connector import BookstackConnector
@@ -42,7 +41,6 @@ from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
from danswer.connectors.xenforo.connector import XenforoConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.db.credentials import backend_update_credential_json
@@ -63,7 +61,6 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -94,12 +91,10 @@ def identify_connector_class(
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.ASANA: AsanaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -129,11 +124,11 @@ def identify_connector_class(
def instantiate_connector(
db_session: Session,
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)

Some files were not shown because too many files have changed in this diff Show More