Compare commits

..

50 Commits

Author SHA1 Message Date
pablodanswer
4eb53ce56f rebase needs fixing 2024-08-19 07:40:53 -07:00
pablodanswer
2fc84ed63e post rebase fix 2024-08-18 16:41:12 -07:00
pablodanswer
722d5e6e54 add sequential tool calls 2024-08-18 16:40:07 -07:00
pablodanswer
14c30d2e4d add env variable 2024-08-18 15:05:44 -07:00
pablodanswer
6abad2fdd3 robust chat session state persistence 2024-08-18 15:05:44 -07:00
pablodanswer
4691e736f6 functional new message carry-over 2024-08-18 15:05:44 -07:00
pablodanswer
5a826a527f properly reset blank screen 2024-08-18 15:05:44 -07:00
pablodanswer
f92d31df70 refactored for stop / regenerate 2024-08-18 15:05:44 -07:00
pablodanswer
1eb786897a proper margin 2024-08-18 15:05:26 -07:00
pablodanswer
72471f9e1d remove parameter 2024-08-18 15:05:26 -07:00
pablodanswer
49c335d06a squash 2024-08-18 15:05:26 -07:00
pablodanswer
fda06b7739 more robust implementation for first messages 2024-08-18 15:05:26 -07:00
pablodanswer
00d44e31b3 validated + cleaner UI 2024-08-18 15:05:26 -07:00
pablodanswer
2a42c1dd18 functional once again post rebase but quite ugly 2024-08-18 15:05:26 -07:00
pablodanswer
05cd25043e add regenerate 2024-08-18 15:05:26 -07:00
pablodanswer
abebff50bb Enable seeding of analytics via file path (#2146)
* enable seeding of analytics via file path

* remove log
2024-08-18 15:05:26 -07:00
pablodanswer
0a7e672832 add handling for poorly formatting model names (#2143) 2024-08-18 15:05:26 -07:00
pablodanswer
221ab9134c add critical error just in case 2024-08-18 15:03:04 -07:00
pablodanswer
f7134202b6 slightly more specific logs 2024-08-18 14:44:10 -07:00
pablodanswer
bea11dc3aa include logs 2024-08-18 14:33:45 -07:00
pablodanswer
374b798071 update typing 2024-08-17 13:51:52 -07:00
pablodanswer
6a2e3edfcd add synchronous wrapper to avoid hampering main event loop 2024-08-17 13:39:22 -07:00
pablodanswer
2ef1731e32 tiny formatting (remove newline) 2024-08-17 09:29:39 -07:00
pablodanswer
7d4d7a5f5d clean final message handling 2024-08-17 01:14:31 -07:00
pablodanswer
ea2f9cf625 cleaner messages 2024-08-15 17:17:03 -07:00
pablodanswer
97dc9c5e31 add back stack trace detail 2024-08-15 16:46:32 -07:00
pablodanswer
249bcd46d9 clearer 2024-08-15 16:10:56 -07:00
pablodanswer
f29b727bc7 remove comments 2024-08-15 16:10:56 -07:00
pablodanswer
31fb6c0753 improve clarity + new SSE handling utility function 2024-08-15 16:10:56 -07:00
pablodanswer
a45e72c298 update utility + copy 2024-08-15 16:10:56 -07:00
pablodanswer
157548817c slightly more robust chat state 2024-08-15 16:10:56 -07:00
pablodanswer
d9396f77d1 remove false comment 2024-08-15 16:10:56 -07:00
pablodanswer
7bae6bbf8f remove log 2024-08-15 16:10:56 -07:00
pablodanswer
1d535769ed robustify 2024-08-15 16:10:56 -07:00
pablodanswer
8584a81fe2 unnecessary list removed 2024-08-15 16:10:56 -07:00
pablodanswer
5f4ac19928 robustify typing 2024-08-15 16:10:56 -07:00
pablodanswer
d898e4f738 remove logs 2024-08-15 16:10:56 -07:00
pablodanswer
19412f0aa0 add ChatState for more robust handling 2024-08-15 16:10:56 -07:00
pablodanswer
c338de30fd add new loading state to prevent collisions 2024-08-15 16:10:56 -07:00
pablodanswer
edfde621b9 formatting 2024-08-15 16:10:56 -07:00
pablodanswer
9306abf911 migrate to streaming response 2024-08-15 16:10:56 -07:00
pablodanswer
70d885b621 cleaner loop + data persistence 2024-08-15 16:10:56 -07:00
pablodanswer
53bea4f859 robustify frontend handling 2024-08-15 16:10:55 -07:00
pablodanswer
a79d734d96 typing 2024-08-15 16:10:28 -07:00
pablodanswer
25cd7de147 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
ab2916c807 robustify switching 2024-08-15 16:10:28 -07:00
pablodanswer
96112f1f95 functional rework of temporary user/assistant ID 2024-08-15 16:10:28 -07:00
pablodanswer
54502b32d3 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
9431e6c06c remove commits 2024-08-15 16:10:28 -07:00
pablodanswer
f18571d580 functional types + sidebar 2024-08-15 16:10:28 -07:00
909 changed files with 19128 additions and 82243 deletions

View File

@@ -1,109 +0,0 @@
name: 'Build and Push Docker Image with Retry'
description: 'Attempts to build and push a Docker image, with a retry on failure'
inputs:
context:
description: 'Build context'
required: true
file:
description: 'Dockerfile location'
required: true
platforms:
description: 'Target platforms'
required: true
pull:
description: 'Always attempt to pull a newer version of the image'
required: false
default: 'true'
push:
description: 'Push the image to registry'
required: false
default: 'true'
load:
description: 'Load the image into Docker daemon'
required: false
default: 'true'
tags:
description: 'Image tags'
required: true
cache-from:
description: 'Cache sources'
required: false
cache-to:
description: 'Cache destinations'
required: false
retry-wait-time:
description: 'Time to wait before attempt 2 in seconds'
required: false
default: '60'
retry-wait-time-2:
description: 'Time to wait before attempt 3 in seconds'
required: false
default: '120'
runs:
using: "composite"
steps:
- name: Build and push Docker image (Attempt 1 of 3)
id: buildx1
uses: docker/build-push-action@v6
continue-on-error: true
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait before attempt 2
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Attempt 2 of 3)
id: buildx2
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Wait before attempt 3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
run: |
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
sleep ${{ inputs.retry-wait-time-2 }}
shell: bash
- name: Build and push Docker image (Attempt 3 of 3)
id: buildx3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
uses: docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
- name: Report failure
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
run: |
echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
shell: bash

View File

@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Backend Image Docker Build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: false
tags: |
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=v0.0.1

View File

@@ -7,17 +7,16 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -28,11 +27,6 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
@@ -42,20 +36,12 @@ jobs:
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -5,18 +5,14 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -35,21 +31,13 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -7,15 +7,11 @@ on:
env:
REGISTRY_IMAGE: danswer/danswer-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
- run-id=${{ github.run_id }}
- tag=platform-${{ matrix.platform }}
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
@@ -39,7 +35,7 @@ jobs:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -116,16 +112,8 @@ jobs:
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: false
build-args: |
DANSWER_VERSION=v0.0.1
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,6 +1,3 @@
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "latest".
name: Tag Latest Version
on:
@@ -12,9 +9,7 @@ on:
jobs:
tag:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1

View File

@@ -1,172 +0,0 @@
# This workflow is intended to be manually triggered via the GitHub Action tab.
# Given a hotfix branch, it will attempt to open a PR to all release branches and
# by default auto merge them
name: Hotfix release branches
on:
workflow_dispatch:
inputs:
hotfix_commit:
description: 'Hotfix commit hash'
required: true
hotfix_suffix:
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
required: true
release_branch_pattern:
description: 'Release branch pattern (regex)'
required: true
default: 'release/.*'
auto_merge:
description: 'Automatically merge the hotfix PRs'
required: true
type: choice
default: 'true'
options:
- true
- false
jobs:
hotfix_release_branches:
permissions: write-all
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# needs RKUO_DEPLOY_KEY for write access to merge PR's
- name: Checkout Repository
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
fetch-depth: 0
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Fetch All Branches
run: |
git fetch --all --prune
- name: Verify Hotfix Commit Exists
run: |
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
- name: Get Release Branches
id: get_release_branches
run: |
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
if [ -z "$BRANCHES" ]; then
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
exit 1
fi
echo "Found release branches:"
echo "$BRANCHES"
# Join the branches into a single line separated by commas
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
# Set the branches as an output
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
# notes on all the vagaries of wiring up automated PR's
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
- name: Create and Merge Pull Requests to Matching Release Branches
env:
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
run: |
# Get the branches from the previous step
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
# Convert BRANCHES to an array
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
# Loop through each release branch and create and merge a PR
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
echo "Processing $RELEASE_BRANCH..."
# Parse out the release version by removing "release/" from the branch name
RELEASE_VERSION=${RELEASE_BRANCH#release/}
echo "Release version parsed: $RELEASE_VERSION"
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
# Checkout the release branch
echo "Checking out $RELEASE_BRANCH"
git checkout "$RELEASE_BRANCH"
# Create the new hotfix branch
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
else
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
git checkout -b "$HOTFIX_BRANCH"
fi
# Check if the hotfix commit is a merge commit
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
# -m 1 uses the target branch as the base (which is what we want)
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
else
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
fi
# Perform the cherry-pick
echo "Executing: $CHERRY_PICK_CMD"
eval "$CHERRY_PICK_CMD"
if [ $? -ne 0 ]; then
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
git cherry-pick --abort
continue
fi
# Push the hotfix branch to the remote
echo "Pushing $HOTFIX_BRANCH..."
git push origin "$HOTFIX_BRANCH"
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
# Check if PR already exists
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
if [ -n "$EXISTING_PR" ]; then
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
continue
fi
# Create a new PR and capture the output
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
# Extract the URL from the output
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
echo "Pull request created: $PR_URL"
# Extract PR number from URL
PR_NUMBER=$(basename "$PR_URL")
echo "Pull request created: $PR_NUMBER"
if [ "$AUTO_MERGE" == "true" ]; then
echo "Attempting to merge pull request #$PR_NUMBER"
# Attempt to merge the PR
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
if [ $? -eq 0 ]; then
echo "Pull request #$PR_NUMBER merged successfully."
else
# Optionally, handle the error or continue
echo "Failed to merge pull request #$PR_NUMBER."
fi
fi
done

View File

@@ -1,178 +0,0 @@
name: Run Integration Tests v2
concurrency:
group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run integration tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -1,68 +0,0 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -1,23 +1,19 @@
name: Python Checks
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
branches: [ main ]
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
@@ -27,9 +23,9 @@ jobs:
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@@ -1,61 +0,0 @@
name: Connector Tests
on:
pull_request:
branches: [main]
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -1,58 +0,0 @@
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

View File

@@ -1,27 +1,22 @@
name: Python Unit Tests
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
branches: [ main ]
jobs:
backend-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
env:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
@@ -32,8 +27,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -1,23 +1,21 @@
name: Quality Checks PR
concurrency:
group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Quality-Checks-PR-${{ github.head_ref }}
cancel-in-progress: true
on:
merge_group:
pull_request: null
jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- uses: pre-commit/action@v3.0.0
with:
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}

View File

@@ -1,54 +0,0 @@
name: Nightly Tag Push
on:
schedule:
- cron: '0 10 * * *' # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@danswer.ai"
- name: Check for existing nightly tag
id: check_tag
run: |
if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then
echo "A tag starting with 'nightly-latest' already exists on HEAD."
echo "tag_exists=true" >> $GITHUB_OUTPUT
else
echo "No tag starting with 'nightly-latest' exists on HEAD."
echo "tag_exists=false" >> $GITHUB_OUTPUT
fi
# don't tag again if HEAD already has a nightly-latest tag on it
- name: Create Nightly Tag
if: steps.check_tag.outputs.tag_exists == 'false'
env:
DATE: ${{ github.run_id }}
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
echo "Creating tag: $TAG_NAME"
git tag $TAG_NAME
- name: Push Tag
if: steps.check_tag.outputs.tag_exists == 'false'
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME

2
.gitignore vendored
View File

@@ -4,6 +4,6 @@
.mypy_cache
.idea
/deployment/data/nginx/app.conf
.vscode/
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml

View File

@@ -1 +0,0 @@
backend/tests/integration/tests/pruning/website

View File

@@ -1,5 +1,5 @@
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
# This will help with development iteration speed and reduce repeat tasks for dev
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=False
DISABLE_LLM_DOC_RELEVANCE=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o
# Python stuff
PYTHONPATH=../backend
PYTHONPATH=./backend
PYTHONUNBUFFERED=1
@@ -49,3 +49,4 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

View File

@@ -1,23 +1,15 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
/*
Copy this file into '.vscode/launch.json' or merge its
contents into your existing configurations.
*/
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
],
"configurations": [
{
"name": "Web Server",
@@ -25,7 +17,7 @@
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"runtimeArgs": [
"run", "dev"
],
@@ -33,12 +25,11 @@
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
@@ -48,16 +39,16 @@
"--reload",
"--port",
"9000"
]
],
"consoleTitle": "Model Server"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -68,32 +59,32 @@
"--reload",
"--port",
"8080"
]
],
"consoleTitle": "API Server"
},
{
"name": "Indexing",
"consoleName": "Indexing",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
"consoleTitle": "Indexing"
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
@@ -102,18 +93,18 @@
},
"args": [
"--no-indexing"
]
],
"consoleTitle": "Background Jobs"
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"type": "python",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -122,12 +113,11 @@
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
@@ -138,16 +128,18 @@
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
},
}
],
"compounds": [
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true
"name": "Run Danswer",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
]
}
]
}

View File

@@ -22,7 +22,7 @@ Your input is vital to making sure that Danswer moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ) /
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-2afut44lv-Rw3kSWu6_OmdAXRpCv80DQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
@@ -48,26 +48,23 @@ We would love to see you there!
## Get Started 🚀
Danswer being a fully functional app, relies on some external software, specifically:
Danswer being a fully functional app, relies on some external pieces of software, specifically:
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
development purposes but also feel free to just use the containers and update with local changes by providing the
`--build` flag.
### Local Set Up
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
It is recommended to use Python version 3.11
If using a lower version, modifications will have to be made to the code.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
If using a higher version, the version of Tensorflow we use may not be available for your platform.
#### Backend: Python requirements
#### Installing Requirements
Currently, we use pip and recommend creating a virtual environment.
For convenience here's a command for it:
@@ -76,9 +73,8 @@ python -m venv .venv
source .venv/bin/activate
```
> **Note:**
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
directory
_For Windows, activate the virtual environment using Command Prompt:_
```bash
@@ -93,38 +89,34 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/ee.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
```bash
playwright install
```
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `danswer/web` run:
```bash
npm i
```
#### Docker containers for external software
You will need Docker installed to run these containers.
Install Playwright (required by the Web Connector)
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
This will update the path to include playwright
Then install Playwright by running:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
playwright install
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
#### Running Danswer locally
#### Dependent Docker Containers
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
```
(index refers to Vespa and relational_db refers to Postgres)
#### Running Danswer
To start the frontend, navigate to `danswer/web` and run:
```bash
npm run dev
@@ -135,10 +127,11 @@ Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```
The first time running Danswer, you will need to run the DB migrations for Postgres.
@@ -161,7 +154,6 @@ To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
@@ -170,58 +162,20 @@ powershell -Command "
"
```
> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
#### Wrapping up
You should now have 4 servers running:
- Web server
- Backend API
- Model server
- Background jobs
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
You've successfully set up a local Danswer instance! 🏁
#### Running the Danswer application in a container
You can run the full Danswer application stack from pre-built images including all external software dependencies.
Navigate to `danswer/deployment/docker_compose` and run:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
```
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
```bash
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
```
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `danswer/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we want to keep it that way!
Danswer is fully type-annotated, and we would like to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
@@ -232,7 +186,6 @@ Please double check that prettier passes before creating a pull request.
### Release Process
Danswer loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
Danswer follows the semver versioning standard.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).

View File

@@ -1,31 +0,0 @@
## Some additional notes for Mac Users
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
### Setting up Python
Ensure [Homebrew](https://brew.sh/) is already set up.
Then install python 3.11.
```bash
brew install python@3.11
```
Add python 3.11 to your path: add the following line to ~/.zshrc
```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```
> **Note:**
> You will need to open a new terminal for the path change above to take effect.
### Setting up Docker
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
ensure it is running before continuing with the docker commands.
### Formatting and Linting
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

View File

@@ -74,7 +74,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Danswer
## Other Noteable Benefits of Danswer
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.

View File

@@ -8,11 +8,8 @@ Edition features outside of personal development or testing purposes. Please rea
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,24 +35,11 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
@@ -91,8 +75,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Set up application files
WORKDIR /app
@@ -105,7 +89,6 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
@@ -115,7 +98,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH=/app
ENV PYTHONPATH /app
# Default command which does nothing
# This container is used by api server and background which specify their own CMD

View File

@@ -7,18 +7,12 @@ You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For mo
visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
@@ -28,18 +22,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
AutoTokenizer.from_pretrained('distilbert-base-uncased', cache_folder='/root/.cache/temp_huggingface/hub/'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Danswer, don't overwrite it with the built in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
WORKDIR /app
@@ -55,6 +45,6 @@ COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH=/app
ENV PYTHONPATH /app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -1,6 +1,6 @@
# A generic, single database configuration.
[DEFAULT]
[alembic]
# path to migration scripts
script_location = alembic
@@ -47,8 +47,7 @@ prepend_sys_path = .
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
@@ -107,12 +106,3 @@ formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
[alembic]
script_location = alembic
version_locations = %(script_location)s/versions
[schema_private]
script_location = alembic_tenants
version_locations = %(script_location)s/versions

View File

@@ -1,198 +1,86 @@
from sqlalchemy.engine.base import Connection
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.background.celery.celery_app import get_all_tenant_ids
# Alembic Config object
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Add your model's MetaData object here for 'autogenerate' support
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
# Set up logging
logger = logging.getLogger(__name__)
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if MULTI_TENANT and schema_name == "public":
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
url = build_connection_string()
context.configure(
connection=connection,
url=url,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
compare_server_default=True,
script_location=config.get_main_option("script_location"),
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
engine = create_async_engine(
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
else:
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
await connectable.dispose()
def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())

View File

@@ -1,27 +0,0 @@
"""add ccpair deletion failure message
Revision ID: 0ebb1d516877
Revises: 52a219fb5233
Create Date: 2024-09-10 15:03:48.233926
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0ebb1d516877"
down_revision = "52a219fb5233"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("deletion_failure_message", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "deletion_failure_message")

View File

@@ -1,26 +0,0 @@
"""add additional data to notifications
Revision ID: 1b10e1fda030
Revises: 6756efa39ada
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "6756efa39ada"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("notification", "additional_data")

View File

@@ -1,102 +0,0 @@
"""add_user_delete_cascades
Revision ID: 1b8206b29c5d
Revises: 35e6853a51d5
Create Date: 2024-09-18 11:48:59.418726
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1b8206b29c5d"
down_revision = "35e6853a51d5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey",
"credential",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey",
"chat_session",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey",
"chat_folder",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key(
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
)
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey",
"notification",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey",
"inputprompt",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
)

View File

@@ -1,135 +0,0 @@
"""embedding model -> search settings
Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
# Rename the table
op.rename_table("embedding_model", "search_settings")
# Add new columns
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
),
)
op.add_column(
"search_settings",
sa.Column(
"multilingual_expansion",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
)
op.add_column(
"search_settings",
sa.Column(
"disable_rerank_for_streaming",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
)
op.add_column(
"search_settings",
sa.Column(
"num_rerank",
sa.Integer(),
nullable=False,
server_default=str(NUM_POSTPROCESSED_RESULTS),
),
)
# Add the new column as nullable initially
op.add_column(
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
)
# Populate the new column with data from the existing embedding_model_id
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
# Create the foreign key constraint
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)
# Make the new column non-nullable
op.alter_column("index_attempt", "search_settings_id", nullable=False)
# Drop the old embedding_model_id column
op.drop_column("index_attempt", "embedding_model_id")
def downgrade() -> None:
# Add back the embedding_model_id column
op.add_column(
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
)
# Populate the old column with data from search_settings_id
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
# Make the old column non-nullable
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
# Drop the foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Drop the new search_settings_id column
op.drop_column("index_attempt", "search_settings_id")
# Rename the table back
op.rename_table("search_settings", "embedding_model")
# Remove added columns
op.drop_column("embedding_model", "num_rerank")
op.drop_column("embedding_model", "rerank_api_key")
op.drop_column("embedding_model", "rerank_provider_type")
op.drop_column("embedding_model", "rerank_model_name")
op.drop_column("embedding_model", "disable_rerank_for_streaming")
op.drop_column("embedding_model", "multilingual_expansion")
op.drop_column("embedding_model", "multipass_indexing")
op.create_foreign_key(
"index_attempt__embedding_model_fk",
"index_attempt",
"embedding_model",
["embedding_model_id"],
["id"],
)

View File

@@ -1,32 +0,0 @@
"""Add Above Below to Persona
Revision ID: 2d2304e27d8c
Revises: 4b08d97e175a
Create Date: 2024-08-21 19:15:15.762948
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2d2304e27d8c"
down_revision = "4b08d97e175a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
op.execute(
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
)
op.alter_column("persona", "chunks_above", nullable=False)
op.alter_column("persona", "chunks_below", nullable=False)
def downgrade() -> None:
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "chunks_above")

View File

@@ -1,90 +0,0 @@
"""Add curator fields
Revision ID: 351faebd379d
Revises: ee3f4b47fad5
Create Date: 2024-08-15 22:37:08.397052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add is_curator column to User__UserGroup table
op.add_column(
"user__user_group",
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
)
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
existing_nullable=False,
)
# Create the association table
op.create_table(
"credential__user_group",
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["credential_id"],
["credential.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
)
op.add_column(
"credential",
sa.Column(
"curator_public", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
# Update existing records to ensure they fit within the BASIC/ADMIN roles
op.execute(
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
)
# Remove is_curator column from User__UserGroup table
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
),
existing_type=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_nullable=False,
)
# Drop the association table
op.drop_table("credential__user_group")
op.drop_column("credential", "curator_public")

View File

@@ -1,64 +0,0 @@
"""server default chosen assistants
Revision ID: 35e6853a51d5
Revises: c99d76fcd298
Create Date: 2024-09-13 13:20:32.885317
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35e6853a51d5"
down_revision = "c99d76fcd298"
branch_labels = None
depends_on = None
DEFAULT_ASSISTANTS = [-2, -1, 0]
def upgrade() -> None:
# Step 1: Update any NULL values to the default value
# This upgrades existing users without ordered assistant
# to have default assistants set to visible assistants which are
# accessible by them.
op.execute(
"""
UPDATE "user" u
SET chosen_assistants = (
SELECT jsonb_agg(
p.id ORDER BY
COALESCE(p.display_priority, 2147483647) ASC,
p.id ASC
)
FROM persona p
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
WHERE p.is_visible = true
AND (p.is_public = true OR pu.user_id IS NOT NULL)
)
WHERE chosen_assistants IS NULL
OR chosen_assistants = 'null'
OR jsonb_typeof(chosen_assistants) = 'null'
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
"""
)
# Step 2: Alter the column to make it non-nullable
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
)
def downgrade() -> None:
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
server_default=None,
)

View File

@@ -1,46 +0,0 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)

View File

@@ -1,34 +0,0 @@
"""change default prune_freq
Revision ID: 4b08d97e175a
Revises: d9ec13955951
Create Date: 2024-08-20 15:28:52.993827
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "4b08d97e175a"
down_revision = "d9ec13955951"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 2592000
WHERE prune_freq = 86400
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 86400
WHERE prune_freq = 2592000
"""
)

View File

@@ -1,66 +0,0 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func
# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last modified represents the last time anything needing syncing to vespa changed
# including row metadata and the document itself. This obviously does not include
# the last_synced column.
op.add_column(
"document",
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
server_default=func.now(),
),
)
# last synced represents the last time this document was synced to Vespa
op.add_column(
"document",
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
)
# Set last_synced to the same value as last_modified for existing rows
op.execute(
"""
UPDATE document
SET last_synced = last_modified
"""
)
op.create_index(
op.f("ix_document_last_modified"),
"document",
["last_modified"],
unique=False,
)
op.create_index(
op.f("ix_document_last_synced"),
"document",
["last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
op.drop_column("document", "last_synced")
op.drop_column("document", "last_modified")

View File

@@ -1,79 +0,0 @@
"""assistant_rework
Revision ID: 55546a7967ee
Revises: 61ff3651add4
Create Date: 2024-09-18 17:00:23.755399
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55546a7967ee"
down_revision = "61ff3651add4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Reworking persona and user tables for new assistant features
# keep track of user's chosen assistants separate from their `ordering`
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET builtin_persona = default_persona")
op.alter_column("persona", "builtin_persona", nullable=False)
op.drop_index("_default_persona_name_idx", table_name="persona")
op.create_index(
"_builtin_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("builtin_persona = true"),
)
op.add_column(
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
)
op.add_column(
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
)
op.execute(
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
)
op.alter_column(
"user",
"visible_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.alter_column(
"user",
"hidden_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.drop_column("persona", "default_persona")
op.add_column(
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
)
def downgrade() -> None:
# Reverting changes made in upgrade
op.drop_column("user", "hidden_assistants")
op.drop_column("user", "visible_assistants")
op.drop_index("_builtin_persona_name_idx", table_name="persona")
op.drop_column("persona", "is_default_persona")
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET default_persona = builtin_persona")
op.alter_column("persona", "default_persona", nullable=False)
op.drop_column("persona", "builtin_persona")
op.create_index(
"_default_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("default_persona = true"),
)

View File

@@ -1,35 +0,0 @@
"""match_any_keywords flag for standard answers
Revision ID: 5c7fdadae813
Revises: efb35676026c
Create Date: 2024-09-13 18:52:59.256478
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c7fdadae813"
down_revision = "efb35676026c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_any_keywords",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_any_keywords")
# ### end Alembic commands ###

View File

@@ -1,30 +0,0 @@
"""add api_version and deployment_name to search settings
Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
)
op.add_column(
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "deployment_name")
op.drop_column("embedding_provider", "api_version")

View File

@@ -1,162 +0,0 @@
"""Add Permission Syncing
Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Admin user who set up connectors will lose access to the docs temporarily
# only way currently to give back access is to rerun from beginning
op.add_column(
"connector_credential_pair",
sa.Column(
"access_type",
sa.String(),
nullable=True,
),
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
)
op.alter_column("connector_credential_pair", "access_type", nullable=False)
op.add_column(
"connector_credential_pair",
sa.Column(
"auto_sync_options",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
)
op.drop_column("connector_credential_pair", "is_public")
op.add_column(
"document",
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
)
op.add_column(
"document",
sa.Column(
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
),
)
op.add_column(
"document",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
op.create_table(
"user__external_user_group_id",
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("user_id"),
)
op.drop_column("external_permission", "user_id")
op.drop_column("email_to_external_user_cache", "user_id")
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
)
op.execute(
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
)
op.alter_column("connector_credential_pair", "is_public", nullable=False)
op.drop_column("connector_credential_pair", "auto_sync_options")
op.drop_column("connector_credential_pair", "access_type")
op.drop_column("connector_credential_pair", "last_time_perm_sync")
op.drop_column("document", "external_user_emails")
op.drop_column("document", "external_user_group_ids")
op.drop_column("document", "is_public")
op.drop_table("user__external_user_group_id")
# Drop the enum type at the end of the downgrade
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -1,153 +0,0 @@
"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa
revision = "6756efa39ada"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None
"""
Migrate chat_session and chat_message tables to use UUID primary keys.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
3. Updates foreign key relationships
4. Removes old integer ID columns
Note: Downgrade will assign new integer IDs, not restore original ones.
"""
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
op.add_column(
"chat_session",
sa.Column(
"new_id",
sa.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
)
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
op.add_column(
"chat_message",
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET new_chat_session_id = cs.new_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "new_id", new_column_name="id")
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.add_column(
"chat_session",
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
)
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
op.execute(
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
)
op.execute(
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
)
op.alter_column("chat_session", "old_id", nullable=False)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
op.add_column(
"chat_message",
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET old_chat_session_id = cs.old_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["old_id"],
ondelete="CASCADE",
)
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "old_id", new_column_name="id")
op.alter_column(
"chat_session",
"id",
type_=sa.Integer(),
existing_type=sa.Integer(),
existing_nullable=False,
existing_server_default=False,
)
# Rename the sequence
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
# Update the default value to use the renamed sequence
op.alter_column(
"chat_session",
"id",
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
)

View File

@@ -9,7 +9,7 @@ import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.key_value_store.factory import get_kv_store
from danswer.dynamic_configs.factory import get_dynamic_config_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,7 +54,9 @@ def upgrade() -> None:
)
try:
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
@@ -69,7 +71,7 @@ def upgrade() -> None:
)
# Delete the dynamic config
get_kv_store().delete("token_budget_settings")
get_dynamic_config_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found

View File

@@ -10,7 +10,7 @@ import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,27 +0,0 @@
"""persona_start_date
Revision ID: 797089dfb4d2
Revises: 55546a7967ee
Create Date: 2024-09-11 14:51:49.785835
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "797089dfb4d2"
down_revision = "55546a7967ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "search_start_date")

View File

@@ -35,22 +35,18 @@ def upgrade() -> None:
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id = (
SELECT id FROM connector_credential_pair ccp
WHERE
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
LIMIT 1
)
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
"""
)
# For good measure
op.execute(
"""
DELETE FROM index_attempt
WHERE connector_credential_pair_id IS NULL
SET connector_credential_pair_id =
CASE
WHEN ia.credential_id IS NULL THEN
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
LIMIT 1)
ELSE
(SELECT id FROM connector_credential_pair
WHERE connector_id = ia.connector_id
AND credential_id = ia.credential_id)
END
WHERE ia.connector_id IS NOT NULL
"""
)

View File

@@ -1,158 +0,0 @@
"""migration confluence to be explicit
Revision ID: a3795dce87be
Revises: 1f60f60c3401
Create Date: 2024-09-01 13:52:12.006740
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import table, column
revision = "a3795dce87be"
down_revision = "1f60f60c3401"
branch_labels: None = None
depends_on: None = None
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
from urllib.parse import urlparse
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(
wiki_url: str,
) -> tuple[str, str, str]:
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
)
if is_confluence_cloud:
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
else:
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
return wiki_base, space, page_id, is_confluence_cloud
def reconstruct_confluence_url(
wiki_base: str, space: str, page_id: str, is_cloud: bool
) -> str:
if is_cloud:
url = f"{wiki_base}/spaces/{space}"
if page_id:
url += f"/pages/{page_id}"
else:
url = f"{wiki_base}/display/{space}"
if page_id:
url += f"/pages/{page_id}"
return url
def upgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
# Fetch all Confluence connectors
connection = op.get_bind()
confluence_connectors = connection.execute(
sa.select(connector).where(
sa.and_(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
).fetchall()
for row in confluence_connectors:
config = row.connector_specific_config
wiki_page_url = config["wiki_page_url"]
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
new_config = {
"wiki_base": wiki_base,
"space": space,
"page_id": page_id,
"is_cloud": is_cloud,
}
for key, value in config.items():
if key not in ["wiki_page_url"]:
new_config[key] = value
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)
def downgrade() -> None:
connector = table(
"connector",
column("id", sa.Integer),
column("source", sa.String()),
column("input_type", sa.String()),
column("connector_specific_config", postgresql.JSONB),
)
confluence_connectors = (
op.get_bind()
.execute(
sa.select(connector).where(
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
)
)
.fetchall()
)
for row in confluence_connectors:
config = row.connector_specific_config
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
wiki_page_url = reconstruct_confluence_url(
config["wiki_base"],
config["space"],
config.get("page_id", ""),
config["is_cloud"],
)
new_config = {"wiki_page_url": wiki_page_url}
new_config.update(
{
k: v
for k, v in config.items()
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
}
)
op.execute(
connector.update()
.where(connector.c.id == row.id)
.values(connector_specific_config=new_config)
)

View File

@@ -1,27 +0,0 @@
"""add last_pruned to the connector_credential_pair table
Revision ID: ac5eaac849f9
Revises: 52a219fb5233
Create Date: 2024-09-10 15:04:26.437118
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last pruned represents the last time the connector was pruned
op.add_column(
"connector_credential_pair",
sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_pruned")

View File

@@ -1,26 +0,0 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")

View File

@@ -1,26 +0,0 @@
"""Add base_url to CloudEmbeddingProvider
Revision ID: bceb1e139447
Revises: a3795dce87be
Create Date: 2024-08-28 17:00:52.554580
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "a3795dce87be"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "api_url")

View File

@@ -1,43 +0,0 @@
"""non nullable default persona
Revision ID: bd2921608c3a
Revises: 797089dfb4d2
Create Date: 2024-09-20 10:28:37.992042
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd2921608c3a"
down_revision = "797089dfb4d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set existing NULL values to False
op.execute(
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
)
# Alter the column to be not nullable with a default value of False
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
)
def downgrade() -> None:
# Revert the changes
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=True,
server_default=None,
)

View File

@@ -1,57 +0,0 @@
"""Add index_attempt_errors table
Revision ID: c5b692fa265c
Revises: 4a951134c801
Create Date: 2024-08-08 14:06:39.581972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c5b692fa265c"
down_revision = "4a951134c801"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column(
"doc_summaries",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# ### end Alembic commands ###

View File

@@ -1,31 +0,0 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@@ -1,31 +0,0 @@
"""Remove _alt suffix from model_name
Revision ID: d9ec13955951
Revises: da4c21c69164
Create Date: 2024-08-20 16:31:32.955686
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d9ec13955951"
down_revision = "da4c21c69164"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE embedding_model
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
WHERE model_name LIKE '%__danswer_alt_index'
"""
)
def downgrade() -> None:
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
pass

View File

@@ -1,65 +0,0 @@
"""chosen_assistants changed to jsonb
Revision ID: da4c21c69164
Revises: c5b692fa265c
Create Date: 2024-08-18 19:06:47.291491
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "da4c21c69164"
down_revision = "c5b692fa265c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.db.search_settings import (
from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
@@ -71,14 +71,14 @@ def upgrade() -> None:
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": IndexModelStatus.PRESENT,
"status": old_embedding_model.status,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model()
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[

View File

@@ -1,26 +0,0 @@
"""add_deployment_name_to_llmprovider
Revision ID: e4334d5b33ba
Revises: ac5eaac849f9
Create Date: 2024-10-04 09:52:34.896867
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e4334d5b33ba"
down_revision = "ac5eaac849f9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("llm_provider", "deployment_name")

View File

@@ -0,0 +1,59 @@
"""migrate tool calls
Revision ID: eb690a089310
Revises: ee3f4b47fad5
Create Date: 2024-08-04 17:07:47.533051
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eb690a089310"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the new column
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# Migrate existing data
op.execute(
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
)
# Drop the old relationship
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
def downgrade() -> None:
# Add back the old column
op.add_column(
"tool_call",
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.create_foreign_key(
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
)
# Migrate data back
op.execute(
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
)
# Drop the new column
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")

View File

@@ -1,7 +1,7 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 2d2304e27d8c
Revises: 4a951134c801
Create Date: 2024-08-12 00:11:50.915845
"""
@@ -12,17 +12,17 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels: None = None
depends_on: None = None
down_revision = "4a951134c801"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("overridden_model", sa.String(length=255), nullable=True),
sa.Column("alternate_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "overridden_model")
op.drop_column("chat_message", "alternate_model")

View File

@@ -1,32 +0,0 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 0ebb1d516877
Create Date: 2024-09-11 13:55:46.101149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "efb35676026c"
down_revision = "0ebb1d516877"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column(
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("standard_answer", "match_regex")
# ### end Alembic commands ###

View File

@@ -1,172 +0,0 @@
"""embedding provider by provider type
Revision ID: f17bf3b0d9f1
Revises: 351faebd379d
Create Date: 2024-08-21 13:13:31.120460
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add provider_type column to embedding_provider
op.add_column(
"embedding_provider",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type with existing name values
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
# Make provider_type not nullable
op.alter_column("embedding_provider", "provider_type", nullable=False)
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Drop the existing primary key constraint
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create a new primary key constraint on provider_type
op.create_primary_key(
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
)
# Add provider_type column to embedding_model
op.add_column(
"embedding_model",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type for existing embedding models
op.execute(
"""
UPDATE embedding_model
SET provider_type = (
SELECT provider_type
FROM embedding_provider
WHERE embedding_provider.id = embedding_model.cloud_provider_id
)
"""
)
# Drop the old id column from embedding_provider
op.drop_column("embedding_provider", "id")
# Drop the name column from embedding_provider
op.drop_column("embedding_provider", "name")
# Drop the default_model_id column from embedding_provider
op.drop_column("embedding_provider", "default_model_id")
# Drop the old cloud_provider_id column from embedding_model
op.drop_column("embedding_model", "cloud_provider_id")
# Create the new foreign key constraint
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["provider_type"],
["provider_type"],
)
def downgrade() -> None:
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Add back the cloud_provider_id column to embedding_model
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
# Assign incrementing IDs to embedding providers
op.execute(
"""
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
)
op.execute(
"""
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
"""
)
# Update cloud_provider_id based on provider_type
op.execute(
"""
UPDATE embedding_model
SET cloud_provider_id = CASE
WHEN provider_type IS NULL THEN NULL
ELSE (
SELECT id
FROM embedding_provider
WHERE embedding_provider.provider_type = embedding_model.provider_type
)
END
"""
)
# Drop the provider_type column from embedding_model
op.drop_column("embedding_model", "provider_type")
# Add back the columns to embedding_provider
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
op.add_column(
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
)
# Drop the existing primary key constraint on provider_type
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create the original primary key constraint on id
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
# Update name with existing provider_type values
op.execute(
"""
UPDATE embedding_provider
SET name = CASE
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
WHEN provider_type = 'COHERE' THEN 'Cohere'
WHEN provider_type = 'GOOGLE' THEN 'Google'
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
ELSE provider_type
END
"""
)
# Drop the provider_type column from embedding_provider
op.drop_column("embedding_provider", "provider_type")
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)

View File

@@ -1,26 +0,0 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: bd2921608c3a
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "bd2921608c3a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("tool", "custom_headers")

View File

@@ -1,26 +0,0 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_column("user", "has_web_login")

View File

@@ -1,3 +0,0 @@
These files are for public table migrations when operating with multi tenancy.
If you are not a Danswer developer, you can ignore this directory entirely.

View File

@@ -1,111 +0,0 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import PublicBase
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [PublicBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,24 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -1,24 +0,0 @@
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "14a83a331951"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"user_tenant_mapping",
sa.Column("email", sa.String(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=False),
sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
sa.UniqueConstraint("email", name="uq_email"),
schema="public",
)
def downgrade() -> None:
op.drop_table("user_tenant_mapping", schema="public")

View File

@@ -1,81 +1,26 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
info = get_access_info_for_document(
db_session=db_session,
document_id=document_id,
)
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=info[2] if info else False,
)
def get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
)
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
document_access_info = get_access_info_for_documents(
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=set(),
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.
for doc_id in document_ids:
if doc_id not in doc_access:
doc_access[doc_id] = get_null_document_access()
return doc_access
def get_access_for_documents(
document_ids: list[str],
@@ -97,7 +42,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
matches one entry in the returned set.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@@ -1,72 +1,30 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class ExternalAccess:
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Danswer
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
cls,
user_emails: list[str | None],
user_groups: list[str],
external_user_emails: list[str],
external_user_group_ids: list[str],
is_public: bool,
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@@ -1,24 +1,10 @@
from danswer.configs.constants import DocumentSource
def prefix_user_email(user_email: str) -> str:
"""Prefixes a user email to eliminate collision with group names.
This applies to both a Danswer user and an External user, this is to make the query time
more efficient"""
return f"user_email:{user_email}"
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user emails.
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def prefix_external_group(ext_group_name: str) -> str:
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -1,20 +1,20 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
store = get_dynamic_config_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store = get_dynamic_config_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -4,29 +4,29 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.key_value_store.store import KeyValueStore
from danswer.key_value_store.store import KvKeyNotFoundError
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: KeyValueStore, preferences: UserPreferences
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",

View File

@@ -5,20 +5,8 @@ from fastapi_users import schemas
class UserRole(str, Enum):
"""
User roles
- Basic can't perform any admin actions
- Admin can perform all admin actions
- Curator can perform admin actions for
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
"""
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
class UserStatus(str, Enum):
@@ -33,10 +21,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -5,69 +5,41 @@ from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
from fastapi_users.jwt import SecretType
from fastapi_users.manager import UserManagerDependency
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.router.common import ErrorCode
from fastapi_users.router.common import ErrorModel
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
@@ -79,21 +51,18 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
@@ -112,7 +81,7 @@ def verify_auth_setting() -> None:
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
@@ -137,35 +106,12 @@ def user_needs_to_be_verified() -> bool:
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
return
if not email:
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
# normalized emails are now being inserted into the db
# we can remove this normalization on read after some time has passed
email_info_whitelist = validate_email(email_whitelist)
except EmailNotValidError:
continue
# oddly, normalization does not include lowercasing the user part of the
# email address ... which we want to allow
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -185,20 +131,6 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -232,84 +164,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
async def on_after_login(
self,
user: User,
request: Request | None = None,
response: Response | None = None,
) -> None:
if response is None or not MULTI_TENANT:
return
tenant_id = get_tenant_id_for_email(user.email)
tenant_token = jwt.encode(
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
)
response.set_cookie(
key="tenant_details",
value=tenant_token,
httponly=True,
secure=WEB_DOMAIN.startswith("https"),
samesite="lax",
)
) -> models.UP:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
@@ -324,116 +188,33 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
account_email=account_email,
expires_at=expires_at,
refresh_token=refresh_token,
request=request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,
"access_token": access_token,
"account_id": account_id,
"account_email": account_email,
"expires_at": expires_at,
"refresh_token": refresh_token,
}
try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
}
user = await self.user_db.create(user_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(
user, update_dict={"oidc_expiry": oidc_expiry}
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
if token:
current_tenant_id.reset(token)
return user
# NOTE: google oauth expires after 1hr. We don't want to force the user to
# re-authenticate that frequently, so for now we'll just ignore this for
# google oauth users
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has registered.")
logger.info(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -443,67 +224,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
verify_email_domain(user.email)
logger.notice(
logger.info(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
return None
# Create a tenant-specific session
async with get_async_session_with_tenant(tenant_id) as tenant_session:
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
tenant_session, User
)
self.user_db = tenant_user_db
# Proceed with authentication
try:
user = await self.get_by_email(email)
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(
user, {"hashed_password": updated_password_hash}
)
return user
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -517,26 +250,21 @@ cookie_transport = CookieTransport(
)
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="jwt" if MULTI_TENANT else "database",
name="database",
transport=cookie_transport,
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
) # type: ignore
get_strategy=get_database_strategy,
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@@ -550,11 +278,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
This way the login router does not need to be included
"""
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
logout_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
@@ -601,8 +327,8 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
db_session: Session = Depends(get_session),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
@@ -613,7 +339,6 @@ async def optional_user(
async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
) -> User | None:
if optional:
return None
@@ -630,11 +355,7 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
@@ -643,40 +364,12 @@ async def double_check_user(
return user
async def current_user_with_expired_token(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user, include_expired=True)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
if DISABLE_AUTH:
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
return user
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
if DISABLE_AUTH:
return None
@@ -684,195 +377,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
detail="Access denied. User is not an admin.",
)
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
class OAuth2AuthorizeResponse(BaseModel):
authorization_url: str
def generate_state_token(
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
return get_oauth_router(
oauth_client,
backend,
get_user_manager,
state_secret,
redirect_url,
associate_by_email,
is_verified_by_default,
)
def get_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UP, models.ID],
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"
if redirect_url is not None:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
redirect_url=redirect_url,
)
else:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
route_name=callback_route_name,
)
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
"/callback",
name=callback_route_name,
description="The response varies based on the authentication backend used.",
responses={
status.HTTP_400_BAD_REQUEST: {
"model": ErrorModel,
"content": {
"application/json": {
"examples": {
"INVALID_STATE_TOKEN": {
"summary": "Invalid state token.",
"value": None,
},
ErrorCode.LOGIN_BAD_CREDENTIALS: {
"summary": "User is inactive.",
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
},
}
}
},
},
},
)
async def callback(
request: Request,
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
) -> RedirectResponse:
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
)
if account_email is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body
if hasattr(response, "status_code"):
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -1,619 +1,343 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
from typing import cast
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Celery
from celery import current_task
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from sentry_sdk.integrations.celery import CeleryIntegration
from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
if state not in READY_STATES:
return
if not task_id:
return
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)'
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
tenant_ids = get_all_tenant_ids()
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
while True:
# Log all the locks in Redis
all_locks = r.keys("*")
logger.notice(f"Current Redis locks: {all_locks}")
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
for tenant_id, lock in sender.primary_worker_locks.items():
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
if lock.owned():
lock.release()
sender.primary_worker_locks = {}
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
class CeleryTaskColoredFormatter(ColoredFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# Requires the Hub component
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
def start(self, worker: Any) -> None:
if not celery_is_worker_primary(worker):
return
# Access the worker's event loop (hub)
hub = worker.consumer.controller.hub
# Schedule the periodic task
self.task_tref = hub.call_repeatedly(
self.interval, self.run_periodic_task, worker
)
task_logger.info("Scheduled periodic task with hub.")
def run_periodic_task(self, worker: Any) -> None:
try:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_locks"):
return
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
r = get_redis_client(tenant_id=tenant_id)
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
if self.task_tref:
self.task_tref.cancel()
task_logger.info("Canceled periodic task with hub.")
celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
_SYNC_BATCH_SIZE = 100
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
raise ValueError(
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
try:
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
return delete_connector_credential_pair(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
)
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if len(doc_ids_to_remove) == 0:
logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
except Exception as e:
logger.exception(
f"Failed to run pruning for connector id {connector_id} due to {e}"
)
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine()) as db_session:
try:
cursor = None
while True:
document_batch, cursor = fetch_documents_for_document_set_paginated(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
if cursor is None:
break
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if should_sync_doc_set(document_set, db_session):
logger.info(f"Syncing the {document_set.name} document set")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
@celery_app.task(
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
logger.info(f"Deleting the {cc_pair.name} connector credential pair")
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
),
)
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
},
}
celery_app.conf.beat_schedule.update(
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
},
}
)

View File

@@ -1,557 +0,0 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=int(self._id), db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
used to implement task prioritization.
This operation is not atomic."""
total_length = 0
for i in range(len(DanswerCeleryPriority)):
queue_name = queue
if i > 0:
queue_name += CELERY_SEPARATOR
queue_name += str(i)
length = r.llen(queue_name)
total_length += cast(int, length)
return total_length

View File

@@ -1,15 +1,13 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -18,53 +16,36 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
def get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
deletion_task = get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
return None
@@ -75,19 +56,93 @@ def get_deletion_attempt_snapshot(
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now!")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
return False
if not last_pruning_task.start_time:
return False
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
all docs using the load_from_state and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
@@ -110,56 +165,6 @@ def extract_ids_from_runnable_connector(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
"""Checks to see if we're listening to the named queue"""
# how to get a list of queues this worker is listening to
# https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
queue_names = list(worker.app.amqp.queues.consume_from.keys())
for queue_name in queue_names:
if queue_name == name:
return True
return False
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken to determine if a celery worker
is 'primary', as defined by us. But the way we do it is to check the hostname set
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("primary"):
return True
return False
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants

View File

@@ -1,104 +0,0 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
import urllib.parse
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@"
REDIS_SCHEME = "redis"
# SSL-specific query parameters for Redis URL
SSL_QUERY_PARAMS = ""
if REDIS_SSL:
REDIS_SCHEME = "rediss"
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
if REDIS_SSL_CA_CERTS:
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
# redis broker settings
# https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
"retry_on_timeout": True,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# there doesn't appear to be a way to set socket_keepalive_options on the redis result backend
redis_socket_keepalive = True
redis_retry_on_timeout = True
redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
# Option 1: Reduces generator task result sizes by roughly 20%
# task_compression = "bzip2"
# task_serializer = "pickle"
# result_compression = "bzip2"
# result_serializer = "pickle"
# accept_content=["pickle"]
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
# can be large. small tasks change very little
# def pickle_bz2_encoder(data):
# return bz2.compress(pickle.dumps(data))
# def pickle_bz2_decoder(data):
# return pickle.loads(bz2.decompress(data))
# from kombu import serialization # To register custom serialization with Celery/Kombu
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]

View File

@@ -1,113 +0,0 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat, tenant_id
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

View File

@@ -1,455 +0,0 @@
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
)
def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair.id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings_instance,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
return tasks_created
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
def try_creating_indexing_task(
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# create the index attempt ... just for tracking purposes
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
return None
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job:
return
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
break
return
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
"""
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
return n_final_progress

View File

@@ -1,137 +0,0 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_tenant(tenant_id) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True

View File

@@ -1,301 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately, e.g. via the web ui.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
)
def connector_pruning_generator_task(
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
task_logger.warning(
f"cc_pair not found for {connector_id} {credential_id}"
)
return
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, redis_increment_callback
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(
celery_app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():
lock.release()

View File

@@ -1,144 +0,0 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_session_with_tenant
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task,
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where there
this the is only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
action = "delete"
chunks_affected = document_index.delete_single(document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
action = "update"
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = document_index.update_single(
document_id, fields=fields
)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
task_logger.info(
f"tenant_id={tenant_id} "
f"document_id={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -1,806 +0,0 @@
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat, tenant_id
)
# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat, tenant_id
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
continue
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(
f"Stale document sync progress: remaining={count} initial={initial_count}"
)
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id_str is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# gating off pruning and indexing work before deletion starts
task_logger.warning(
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
# Read result state BEFORE generator_complete_key to avoid a race condition
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return False
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
n_deletion = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
return True
@shared_task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = document_index.update_single(document_id, fields=fields)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
task_logger.info(
f"document_id={document_id} action=sync chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -0,0 +1,196 @@
"""
To delete a connector / credential pair:
(1) find all documents associated with connector / credential pair where there
this the is only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_cnts = get_document_connector_cnts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
num_docs_deleted = 0
while True:
documents = get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
num_docs_deleted += len(documents)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.info(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted

View File

@@ -1,6 +1,5 @@
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -12,25 +11,26 @@ from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import transition_attempt_to_in_progress
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -41,17 +41,16 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
) -> ConnectorRunner:
) -> GenerateDocumentsOutput:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
Returns an iterator of document batches and whether the returned documents
Returns an interator of document batches and whether the returned documents
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
@@ -59,117 +58,111 @@ def _get_connector_runner(
try:
runnable_connector = instantiate_connector(
db_session=db_session,
source=attempt.connector_credential_pair.connector.source,
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
attempt.connector_credential_pair.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if (
attempt.connector_credential_pair.connector_id is None
or attempt.connector_credential_pair.connector_id is None
):
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
db_embedding_model
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
ignore_time_skip=index_attempt.from_beginning
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
tenant_id=tenant_id,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
earliest_index_time = (
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
)
last_successful_index_time = (
earliest_index_time
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
db_connector.indexing_start.timestamp()
if index_attempt.from_beginning and db_connector.indexing_start is not None
else (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
@@ -188,12 +181,11 @@ def _run_indexing(
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
connector_runner = _get_connector_runner(
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
@@ -201,19 +193,15 @@ def _run_indexing(
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_cc_pair)
db_session.refresh(db_connector)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and db_embedding_model.status != IndexModelStatus.FUTURE
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
@@ -221,9 +209,7 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
)
raise RuntimeError("Index Attempt was canceled")
batch_description = []
for doc in doc_batch:
@@ -242,15 +228,13 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
@@ -263,9 +247,6 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
@@ -280,7 +261,7 @@ def _run_indexing(
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
logger.info(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
@@ -296,7 +277,7 @@ def _run_indexing(
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
logger.info(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
@@ -308,7 +289,7 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
@@ -334,52 +315,15 @@ def _run_indexing(
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
logger.info(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
logger.info("Memory tracer stopped.")
mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -388,46 +332,70 @@ def _run_indexing(
run_dt=run_end_dt,
)
elapsed_time = time.time() - start_time
logger.info(
f"Connector succeeded: docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
) -> None:
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress(attempt, db_session)
return attempt
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
_run_indexing(db_session, attempt)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -48,9 +48,9 @@ class DanswerTracer:
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
logger.info(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
logger.info(f"* {line}")
@staticmethod
def log_diff(
@@ -60,9 +60,9 @@ class DanswerTracer:
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
logger.info(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
logger.info(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:

View File

@@ -14,6 +14,14 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
@@ -85,16 +93,9 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# register_task must come before fn = apply_async or else the task
# might run mark_task_start (and crash) before the task row exists
db_task = register_task(task_name, db_session)
# mark the task as started
task = fn(args, kwargs, *other_args, **other_kwargs)
# we update the celery task id for diagnostic purposes
# but it isn't currently used by any code
db_task.task_id = task.id
db_session.commit()
register_task(task.id, task_name, db_session)
return task

View File

@@ -1,494 +1,462 @@
# TODO(rkuo): delete after background indexing via celery is fully vetted
# import logging
# import time
# from datetime import datetime
# import dask
# from dask.distributed import Client
# from dask.distributed import Future
# from distributed import LocalCluster
# from sqlalchemy import text
# from sqlalchemy.exc import ProgrammingError
# from sqlalchemy.orm import Session
# from danswer.background.indexing.dask_utils import ResourceLogger
# from danswer.background.indexing.job_client import SimpleJob
# from danswer.background.indexing.job_client import SimpleJobClient
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
# from danswer.configs.app_configs import MULTI_TENANT
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
# from danswer.configs.constants import DocumentSource
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
# from danswer.configs.constants import TENANT_ID_PREFIX
# from danswer.db.connector import fetch_connectors
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
# from danswer.db.engine import get_db_current_time
# from danswer.db.engine import get_session_with_tenant
# from danswer.db.engine import get_sqlalchemy_engine
# from danswer.db.engine import SqlEngine
# from danswer.db.index_attempt import create_index_attempt
# from danswer.db.index_attempt import get_index_attempt
# from danswer.db.index_attempt import get_inprogress_index_attempts
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
# from danswer.db.index_attempt import get_not_started_index_attempts
# from danswer.db.index_attempt import mark_attempt_failed
# from danswer.db.models import ConnectorCredentialPair
# from danswer.db.models import IndexAttempt
# from danswer.db.models import IndexingStatus
# from danswer.db.models import IndexModelStatus
# from danswer.db.models import SearchSettings
# from danswer.db.search_settings import get_current_search_settings
# from danswer.db.search_settings import get_secondary_search_settings
# from danswer.db.swap_index import check_index_swap
# from danswer.document_index.vespa.index import VespaIndex
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
# from danswer.utils.logger import setup_logger
# from danswer.utils.variable_functionality import global_version
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
# from shared_configs.configs import LOG_LEVEL
# logger = setup_logger()
# # If the indexing dies, it's most likely due to resource constraints,
# # restarting just delays the eventual failure, not useful to the user
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
# _UNEXPECTED_STATE_FAILURE_REASON = (
# "Stopped mid run, likely due to the background process being killed"
# )
# def _should_create_new_indexing(
# cc_pair: ConnectorCredentialPair,
# last_index: IndexAttempt | None,
# search_settings_instance: SearchSettings,
# secondary_index_building: bool,
# db_session: Session,
# ) -> bool:
# connector = cc_pair.connector
# # don't kick off indexing for `NOT_APPLICABLE` sources
# if connector.source == DocumentSource.NOT_APPLICABLE:
# return False
# # User can still manually create single indexing attempts via the UI for the
# # currently in use index
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# if (
# search_settings_instance.status == IndexModelStatus.PRESENT
# and secondary_index_building
# ):
# return False
# # When switching over models, always index at least once
# if search_settings_instance.status == IndexModelStatus.FUTURE:
# if last_index:
# # No new index if the last index attempt succeeded
# # Once is enough. The model will never be able to swap otherwise.
# if last_index.status == IndexingStatus.SUCCESS:
# return False
# # No new index if the last index attempt is waiting to start
# if last_index.status == IndexingStatus.NOT_STARTED:
# return False
# # No new index if the last index attempt is running
# if last_index.status == IndexingStatus.IN_PROGRESS:
# return False
# else:
# if (
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
# ): # Ingestion API
# return False
# return True
# # If the connector is paused or is the ingestion API, don't index
# # NOTE: during an embedding model switch over, the following logic
# # is bypassed by the above check for a future model
# if (
# not cc_pair.status.is_active()
# or connector.id == 0
# or connector.source == DocumentSource.INGESTION_API
# ):
# return False
# if not last_index:
# return True
# if connector.refresh_freq is None:
# return False
# # Only one scheduled/ongoing job per connector at a time
# # this prevents cases where
# # (1) the "latest" index_attempt is scheduled so we show
# # that in the UI despite another index_attempt being in-progress
# # (2) multiple scheduled index_attempts at a time
# if (
# last_index.status == IndexingStatus.NOT_STARTED
# or last_index.status == IndexingStatus.IN_PROGRESS
# ):
# return False
# current_db_time = get_db_current_time(db_session)
# time_since_index = current_db_time - last_index.time_updated
# return time_since_index.total_seconds() >= connector.refresh_freq
# def _mark_run_failed(
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
# ) -> None:
# """Marks the `index_attempt` row as failed + updates the `
# connector_credential_pair` to reflect that the run failed"""
# logger.warning(
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
# )
# mark_attempt_failed(
# index_attempt=index_attempt,
# db_session=db_session,
# failure_reason=failure_reason,
# )
# """Main funcs"""
# def create_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
# ) -> None:
# """Creates new indexing jobs for each connector / credential pair which is:
# 1. Enabled
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
# 3. There is not already an ongoing indexing attempt for this pair
# """
# with get_session_with_tenant(tenant_id) as db_session:
# ongoing: set[tuple[int | None, int]] = set()
# for attempt_id in existing_jobs:
# attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# if attempt is None:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
# "indexing jobs"
# )
# continue
# ongoing.add(
# (
# attempt.connector_credential_pair_id,
# attempt.search_settings_id,
# )
# )
# # Get the primary search settings
# primary_search_settings = get_current_search_settings(db_session)
# search_settings = [primary_search_settings]
# # Check for secondary search settings
# secondary_search_settings = get_secondary_search_settings(db_session)
# if secondary_search_settings is not None:
# # If secondary settings exist, add them to the list
# search_settings.append(secondary_search_settings)
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
# for cc_pair in all_connector_credential_pairs:
# for search_settings_instance in search_settings:
# # Check if there is an ongoing indexing attempt for this connector credential pair
# if (cc_pair.id, search_settings_instance.id) in ongoing:
# continue
# last_attempt = get_last_attempt_for_cc_pair(
# cc_pair.id, search_settings_instance.id, db_session
# )
# if not _should_create_new_indexing(
# cc_pair=cc_pair,
# last_index=last_attempt,
# search_settings_instance=search_settings_instance,
# secondary_index_building=len(search_settings) > 1,
# db_session=db_session,
# ):
# continue
# create_index_attempt(
# cc_pair.id, search_settings_instance.id, db_session
# )
# def cleanup_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# tenant_id: str | None,
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# # clean up completed jobs
# with get_session_with_tenant(tenant_id) as db_session:
# for attempt_id, job in existing_jobs.items():
# index_attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# # do nothing for ongoing jobs that haven't been stopped
# if not job.done():
# if not index_attempt:
# continue
# if not index_attempt.is_finished():
# continue
# if job.status == "error":
# logger.error(job.exception())
# job.release()
# del existing_jobs_copy[attempt_id]
# if not index_attempt:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
# "up indexing jobs"
# )
# continue
# if (
# index_attempt.status == IndexingStatus.IN_PROGRESS
# or job.status == "error"
# ):
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# # clean up in-progress jobs that were never completed
# try:
# connectors = fetch_connectors(db_session)
# for connector in connectors:
# in_progress_indexing_attempts = get_inprogress_index_attempts(
# connector.id, db_session
# )
# for index_attempt in in_progress_indexing_attempts:
# if index_attempt.id in existing_jobs:
# # If index attempt is canceled, stop the run
# if index_attempt.status == IndexingStatus.FAILED:
# existing_jobs[index_attempt.id].cancel()
# # check to see if the job has been updated in last `timeout_hours` hours, if not
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
# # on the fact that the `time_updated` field is constantly updated every
# # batch of documents indexed
# current_db_time = get_db_current_time(db_session=db_session)
# time_since_update = current_db_time - index_attempt.time_updated
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
# existing_jobs[index_attempt.id].cancel()
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason="Indexing run frozen - no updates in the last three hours. "
# "The run will be re-attempted at next scheduled indexing time.",
# )
# else:
# # If job isn't known, simply mark it as failed
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# except ProgrammingError:
# logger.debug(f"No Connector Table exists for: {tenant_id}")
# return existing_jobs_copy
# def kickoff_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# client: Client | SimpleJobClient,
# secondary_client: Client | SimpleJobClient,
# tenant_id: str | None,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# current_session = get_session_with_tenant(tenant_id)
# # Don't include jobs waiting in the Dask queue that just haven't started running
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
# with current_session as db_session:
# # get_not_started_index_attempts orders its returned results from oldest to newest
# # we must process attempts in a FIFO manner to prevent connector starvation
# new_indexing_attempts = [
# (attempt, attempt.search_settings)
# for attempt in get_not_started_index_attempts(db_session)
# if attempt.id not in existing_jobs
# ]
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
# if not new_indexing_attempts:
# return existing_jobs
# indexing_attempt_count = 0
# primary_client_full = False
# secondary_client_full = False
# for attempt, search_settings in new_indexing_attempts:
# if primary_client_full and secondary_client_full:
# break
# use_secondary_index = (
# search_settings.status == IndexModelStatus.FUTURE
# if search_settings is not None
# else False
# )
# if attempt.connector_credential_pair.connector is None:
# logger.warning(
# f"Skipping index attempt as Connector has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Connector is null"
# )
# continue
# if attempt.connector_credential_pair.credential is None:
# logger.warning(
# f"Skipping index attempt as Credential has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Credential is null"
# )
# continue
# if not use_secondary_index:
# if not primary_client_full:
# run = client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# primary_client_full = True
# else:
# if not secondary_client_full:
# run = secondary_client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# secondary_client_full = True
# if run:
# if indexing_attempt_count == 0:
# logger.info(
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
# )
# indexing_attempt_count += 1
# secondary_str = " (secondary index)" if use_secondary_index else ""
# logger.info(
# f"Indexing dispatched{secondary_str}: "
# f"attempt_id={attempt.id} "
# f"connector='{attempt.connector_credential_pair.connector.name}' "
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
# )
# existing_jobs_copy[attempt.id] = run
# if indexing_attempt_count > 0:
# logger.info(
# f"Indexing dispatch results: "
# f"initial_pending={len(new_indexing_attempts)} "
# f"started={indexing_attempt_count} "
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
# )
# return existing_jobs_copy
# def get_all_tenant_ids() -> list[str] | list[None]:
# if not MULTI_TENANT:
# return [None]
# with get_session_with_tenant(tenant_id="public") as session:
# result = session.execute(
# text(
# """
# SELECT schema_name
# FROM information_schema.schemata
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
# )
# )
# tenant_ids = [row[0] for row in result]
# valid_tenants = [
# tenant
# for tenant in tenant_ids
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
# ]
# return valid_tenants
# def update_loop(
# delay: int = 10,
# num_workers: int = NUM_INDEXING_WORKERS,
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
# ) -> None:
# if not MULTI_TENANT:
# # We can use this function as we are certain only the public schema exists
# # (explicitly for the non-`MULTI_TENANT` case)
# engine = get_sqlalchemy_engine()
# with Session(engine) as db_session:
# check_index_swap(db_session=db_session)
# search_settings = get_current_search_settings(db_session)
# # So that the first time users aren't surprised by really slow speed of first
# # batch of documents indexed
# if search_settings.provider_type is None:
# logger.notice("Running a first inference to warm up embedding model")
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(
# embedding_model=embedding_model,
# )
# logger.notice("First inference complete.")
# client_primary: Client | SimpleJobClient
# client_secondary: Client | SimpleJobClient
# if DASK_JOB_CLIENT_ENABLED:
# cluster_primary = LocalCluster(
# n_workers=num_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# cluster_secondary = LocalCluster(
# n_workers=num_secondary_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# client_primary = Client(cluster_primary)
# client_secondary = Client(cluster_secondary)
# if LOG_LEVEL.lower() == "debug":
# client_primary.register_worker_plugin(ResourceLogger())
# else:
# client_primary = SimpleJobClient(n_workers=num_workers)
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
# logger.notice("Startup complete. Waiting for indexing jobs...")
# while True:
# start = time.time()
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
# if existing_jobs:
# logger.debug(
# "Found existing indexing jobs: "
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
# )
# try:
# tenants = get_all_tenant_ids()
# for tenant_id in tenants:
# try:
# logger.debug(
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
# )
# with get_session_with_tenant(tenant_id) as db_session:
# index_to_expire = check_index_swap(db_session=db_session)
# if index_to_expire and tenant_id and MULTI_TENANT:
# VespaIndex.delete_entries_by_tenant_id(
# tenant_id=tenant_id,
# index_name=index_to_expire.index_name,
# )
# if not MULTI_TENANT:
# search_settings = get_current_search_settings(db_session)
# if search_settings.provider_type is None:
# logger.notice(
# "Running a first inference to warm up embedding model"
# )
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(embedding_model=embedding_model)
# logger.notice("First inference complete.")
# tenant_jobs = existing_jobs.get(tenant_id, {})
# tenant_jobs = cleanup_indexing_jobs(
# existing_jobs=tenant_jobs, tenant_id=tenant_id
# )
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
# tenant_jobs = kickoff_indexing_jobs(
# existing_jobs=tenant_jobs,
# client=client_primary,
# secondary_client=client_secondary,
# tenant_id=tenant_id,
# )
# existing_jobs[tenant_id] = tenant_jobs
# except Exception as e:
# logger.exception(
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
# )
# except Exception as e:
# logger.exception(f"Failed to run update due to {e}")
# sleep_time = delay - (time.time() - start)
# if sleep_time > 0:
# time.sleep(sleep_time)
# def update__main() -> None:
# set_is_ee_based_on_env_variable()
# # initialize the Postgres connection pool
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
# logger.notice("Starting indexing service")
# update_loop()
# if __name__ == "__main__":
# update__main()
import logging
import time
from datetime import datetime
import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# If the indexing dies, it's most likely due to resource constraints,
# restarting just delays the eventual failure, not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})
_UNEXPECTED_STATE_FAILURE_REASON = (
"Stopped mid run, likely due to the background process being killed"
)
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False
# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if cc_pair.status == ConnectorCredentialPairStatus.PAUSED or connector.id == 0:
return False
if not last_index:
return True
if connector.refresh_freq is None:
return False
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
return time_since_index.total_seconds() >= connector.refresh_freq
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason=failure_reason,
)
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
if attempt is None:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
"indexing jobs"
)
continue
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.embedding_model_id,
)
)
embedding_models = [get_current_db_embedding_model(db_session)]
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for model in embedding_models:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, model.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(cc_pair.id, model.id, db_session)
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if job.status == "error":
logger.error(job.exception())
job.release()
del existing_jobs_copy[attempt_id]
if not index_attempt:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
"up indexing jobs"
)
continue
if (
index_attempt.status == IndexingStatus.IN_PROGRESS
or job.status == "error"
):
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
return existing_jobs_copy
def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine()
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector_credential_pair.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.connector_credential_pair.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
continue
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False,
)
if run:
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
logger.info(
f"Indexing dispatched{secondary_str}: "
f"attempt_id={attempt.id} "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.cloud_provider_id is None:
logger.debug("Running a first inference to warm up embedding model")
warm_up_bi_encoder(
embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warning about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
client_primary = Client(cluster_primary)
client_secondary = Client(cluster_secondary)
if LOG_LEVEL.lower() == "debug":
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
try:
with Session(get_sqlalchemy_engine()) as db_session:
check_index_swap(db_session)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
client=client_primary,
secondary_client=client_secondary,
)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
logger.info("Starting indexing service")
update_loop()
if __name__ == "__main__":
update__main()

View File

@@ -1,8 +1,6 @@
import re
from typing import cast
from uuid import UUID
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
@@ -35,11 +33,10 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: UUID,
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
stop_at_message_id: int | None = None,
parent_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
@@ -65,12 +62,7 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
# Break if at the end of the chain
# or have reached the `final_id` of the submitted message
if not child_msg or (
stop_at_message_id and current_message.id == stop_at_message_id
):
if not child_msg or (parent_id and current_message.id == parent_id):
break
current_message = id_to_msg.get(child_msg)
@@ -168,31 +160,3 @@ def reorganize_citations(
new_citation_info[citation.citation_num] = citation
return new_answer, list(new_citation_info.values())
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:
"""
Extract headers specified in pass_through_headers from input headers.
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
Args:
headers: Input headers as dict or Headers object.
Returns:
dict: Filtered headers based on pass_through_headers.
"""
if not pass_through_headers:
return {}
extracted_headers: dict[str, str] = {}
for key in pass_through_headers:
if key in headers:
extracted_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers

View File

@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
@@ -17,32 +18,30 @@ from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
db_session: Session, prompts_yaml: str = PROMPTS_YAML
) -> None:
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
@@ -50,117 +49,117 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -10,7 +9,6 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -36,35 +34,16 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class LLMRelevanceFilterResponse(BaseModel):
llm_selected_doc_indices: list[int]
class FinalUsedContextDocsResponse(BaseModel):
final_context_docs: list[LlmDoc]
relevant_chunk_indices: list[int]
class RelevanceAnalysis(BaseModel):
@@ -85,6 +64,10 @@ class DocumentRelevance(BaseModel):
relevance_summaries: dict[str, RelevanceAnalysis]
class Delimiter(BaseModel):
delimiter: bool
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
@@ -97,16 +80,6 @@ class CitationInfo(BaseModel):
document_id: str
class AllCitations(BaseModel):
citations: list[CitationInfo]
# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
citation_map: dict[int, int]
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
@@ -152,7 +125,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_selected_doc_indices: list[int] | None = None
llm_chunks_indices: list[int] | None = None
error_msg: str | None = None
@@ -161,7 +134,7 @@ class ImageGenerationDisplay(BaseModel):
class CustomToolResponse(BaseModel):
response: ToolResultType
response: dict
tool_name: str
@@ -173,7 +146,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
| Delimiter
)

View File

@@ -1,4 +1,3 @@
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
@@ -7,21 +6,15 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import Delimiter
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -39,13 +32,13 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -77,9 +70,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
@@ -94,27 +85,24 @@ from danswer.tools.internet_search.internet_search_tool import (
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallMetadata
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def _translate_citations(
def translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> MessageSpecificCitations:
) -> dict[int, int]:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
@@ -129,7 +117,7 @@ def _translate_citations(
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
return citation_to_saved_doc_id_map
def _handle_search_tool_response_summary(
@@ -251,15 +239,13 @@ ChatPacket = (
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| Delimiter
)
ChatPacketStream = Iterator[ChatPacket]
@@ -277,9 +263,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -287,11 +271,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
try:
user_id = user.id if user is not None else None
@@ -348,9 +327,9 @@ def stream_chat_message_objects(
Callable[[str], list[int]], llm_tokenizer.encode
)
search_settings = get_current_search_settings(db_session)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
@@ -371,7 +350,7 @@ def stream_chat_message_objects(
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
stop_at_message_id=parent_id,
parent_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
@@ -451,7 +430,6 @@ def stream_chat_message_objects(
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
)
# Generates full documents currently
@@ -483,6 +461,8 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
@@ -497,17 +477,16 @@ def stream_chat_message_objects(
reserved_assistant_message_id=reserved_message_id,
)
overridden_model = (
alternate_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
alternate_model=alternate_model,
# message=,
# rephrased_query=,
# token_count=,
@@ -566,26 +545,7 @@ def stream_chat_message_objects(
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
@@ -604,7 +564,7 @@ def stream_chat_message_objects(
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
@@ -616,7 +576,6 @@ def stream_chat_message_objects(
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
@@ -635,18 +594,8 @@ def stream_chat_message_objects(
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
@@ -661,6 +610,7 @@ def stream_chat_message_objects(
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
tool_has_been_called = False # TODO remove
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -701,6 +651,8 @@ def stream_chat_message_objects(
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
tool_has_been_called = True
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
@@ -711,11 +663,9 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -738,14 +688,9 @@ def stream_chat_message_objects(
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
relevant_chunk_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -778,89 +723,137 @@ def stream_chat_message_objects(
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
logger.exception("Failed to process chat message.")
if isinstance(packet, Delimiter):
db_citations = None
error_msg = str(e)
yield StreamingError(error=error_msg)
db_session.rollback()
return
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query
if qa_docs_response
else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
yield Delimiter(delimiter=True)
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
)
else:
if isinstance(packet, ToolCallMetadata):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
stack_trace = traceback.format_exc()
logger.exception(f"Failed to process chat message: {error_msg}")
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=error_msg)
db_session.rollback()
return
# Post-LLM answer processing
try:
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)
if not tool_has_been_called:
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else []
),
)
else None,
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@log_generator_function_time()
@@ -869,7 +862,6 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
@@ -879,8 +871,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
yield get_json_line(obj.dict())

View File

@@ -42,7 +42,8 @@ prompts:
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true

View File

@@ -1,4 +1,4 @@
from typing_extensions import TypedDict # noreorder
from typing import TypedDict
from pydantic import BaseModel

View File

@@ -53,6 +53,7 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@@ -92,14 +93,6 @@ SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Danswer will listen to the `expires_at` returned by the identity
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
# after this time has elapsed. Disabled since by default many auth providers
# have very short expiry times (e.g. 1 hour) which provide a poor user experience
TRACK_EXTERNAL_IDP_EXPIRY = (
os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true"
)
#####
# DB Configs
@@ -115,23 +108,16 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
# The default below is for dockerized deployment
VESPA_DEPLOYMENT_ZIP = (
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@@ -140,15 +126,9 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
)
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -161,43 +141,6 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
# will propagate to both our redis client as well as celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))
# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting to None may help when there is a proxy in the way closing idle connections
CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
CELERY_BROKER_POOL_LIMIT = int(
os.environ.get("CELERY_BROKER_POOL_LIMIT", CELERY_BROKER_POOL_LIMIT_DEFAULT)
)
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
#####
# Connector Configs
#####
@@ -249,8 +192,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
]
# Avoid to get archived pages
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
@@ -261,12 +204,7 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 50 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
@@ -274,10 +212,6 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
@@ -301,7 +235,7 @@ ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
@@ -361,14 +295,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
#####
# File based Key Value store no longer used
DYNAMIC_CONFIG_STORE = "PostgresBackedDynamicConfigStore"
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
@@ -389,10 +321,6 @@ LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
LOG_POSTGRES_CONN_COUNTS = (
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
@@ -406,11 +334,6 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
)
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "5")
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
#####
# Enterprise Edition Configs
@@ -422,39 +345,3 @@ PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() ==
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Azure DALL-E Configurations
AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION")
AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
# Cloud configuration
# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Security and authentication
SECRET_JWT_KEY = os.environ.get(
"SECRET_JWT_KEY", ""
) # Used for encryption of the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
) # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
"EXPECTED_API_KEY", ""
) # Additional security check for the control plane API
# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)
# JWT configuration
JWT_ALGORITHM = "HS256"

View File

@@ -31,9 +31,8 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
@@ -45,7 +44,7 @@ DISABLE_LLM_QUERY_REPHRASE = (
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
@@ -54,7 +53,7 @@ HYBRID_ALPHA_KEYWORD = max(
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is most of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephase the query
@@ -83,15 +82,8 @@ DISABLE_LLM_DOC_RELEVANCE = (
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# Set this to "true" to hard delete chats
# This will make chats unviewable by admins after a user deletes them
# As opposed to soft deleting them, which just hides them from non-admin users
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

View File

@@ -1,6 +1,3 @@
import platform
import socket
from enum import auto
from enum import Enum
SOURCE_TYPE = "source_type"
@@ -15,6 +12,10 @@ ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
# For tool calling
MAXIMUM_TOOL_CALL_SEQUENCE = 5
# For chunking/processing chunks
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
@@ -31,22 +32,14 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -56,7 +49,6 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
@@ -68,23 +60,9 @@ KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -121,22 +99,15 @@ class DocumentSource(str, Enum):
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
ASANA = "asana"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
class BlobType(str, Enum):
@@ -161,15 +132,6 @@ class AuthType(str, Enum):
OIDC = "oidc"
SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
SLACK = "Slack"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
@@ -203,44 +165,3 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class DanswerCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_INDEXING = "connector_indexing"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()
MEDIUM = auto()
LOW = auto()
LOWEST = auto()
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore

View File

@@ -73,15 +73,3 @@ DANSWER_BOT_FEEDBACK_REMINDER = int(
DANSWER_BOT_REPHRASE_MESSAGE = (
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
)
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
# responses DanswerBot can send in a given time period.
# Set to 0 to disable the limit.
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
)
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
# of seconds until the response limit is reset.
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
)

View File

@@ -39,13 +39,9 @@ SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
# User's set embedding batch size overrides the default encoding batch sizes
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
@@ -55,23 +51,37 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####
# NOTE: the 3 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument
# Additionally Danswer supports GPT4All and custom request library based models
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
# This is the minimum token context we will leave for the LLM to generate an answer
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily

Some files were not shown because too many files have changed in this diff Show More