mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00
Compare commits
1 Commits
dropdown
...
access_tok
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6c41f3663 |
47
.github/actions/custom-build-and-push/action.yml
vendored
47
.github/actions/custom-build-and-push/action.yml
vendored
@@ -32,20 +32,16 @@ inputs:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before attempt 2 in seconds'
|
||||
description: 'Time to wait before retry in seconds'
|
||||
required: false
|
||||
default: '60'
|
||||
retry-wait-time-2:
|
||||
description: 'Time to wait before attempt 3 in seconds'
|
||||
required: false
|
||||
default: '120'
|
||||
default: '5'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (Attempt 1 of 3)
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
@@ -58,17 +54,16 @@ runs:
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait before attempt 2
|
||||
- name: Wait to retry
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Attempt 2 of 3)
|
||||
id: buildx2
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
@@ -79,31 +74,3 @@ runs:
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait before attempt 3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
run: |
|
||||
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time-2 }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Attempt 3 of 3)
|
||||
id: buildx3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Report failure
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
|
||||
run: |
|
||||
echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
|
||||
shell: bash
|
||||
|
||||
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
@@ -6,24 +6,20 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
## Accepted Risk
|
||||
[Any know risks or failure modes to point out to reviewers]
|
||||
|
||||
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
## Related Issue(s)
|
||||
[If applicable, link to the issue(s) this PR addresses]
|
||||
|
||||
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
## Checklist:
|
||||
- [ ] All of the automated tests pass
|
||||
- [ ] All PR comments are addressed and marked resolved
|
||||
- [ ] If there are migrations, they have been rebased to latest main
|
||||
- [ ] If there are new dependencies, they are added to the requirements
|
||||
- [ ] If there are new environment variables, they are added to all of the deployment methods
|
||||
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- [ ] Docker images build and basic functionalities work
|
||||
- [ ] Author has done a final read through of the PR right before merge
|
||||
|
||||
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
name: Build and Push Cloud Web Image on Tag
|
||||
# Identical to the web container build, but with correct image tag and build args
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
- run-id=${{ github.run_id }}
|
||||
- tag=platform-${{ matrix.platform }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
NEXT_PUBLIC_CLOUD_ENABLED=true
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -11,11 +11,8 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
- run-id=${{ github.run_id }}
|
||||
- tag=platform-${{ matrix.platform }}
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
172
.github/workflows/hotfix-release-branches.yml
vendored
172
.github/workflows/hotfix-release-branches.yml
vendored
@@ -1,172 +0,0 @@
|
||||
# This workflow is intended to be manually triggered via the GitHub Action tab.
|
||||
# Given a hotfix branch, it will attempt to open a PR to all release branches and
|
||||
# by default auto merge them
|
||||
|
||||
name: Hotfix release branches
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
hotfix_commit:
|
||||
description: 'Hotfix commit hash'
|
||||
required: true
|
||||
hotfix_suffix:
|
||||
description: 'Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})'
|
||||
required: true
|
||||
release_branch_pattern:
|
||||
description: 'Release branch pattern (regex)'
|
||||
required: true
|
||||
default: 'release/.*'
|
||||
auto_merge:
|
||||
description: 'Automatically merge the hotfix PRs'
|
||||
required: true
|
||||
type: choice
|
||||
default: 'true'
|
||||
options:
|
||||
- true
|
||||
- false
|
||||
|
||||
jobs:
|
||||
hotfix_release_branches:
|
||||
permissions: write-all
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
|
||||
# needs RKUO_DEPLOY_KEY for write access to merge PR's
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
|
||||
- name: Fetch All Branches
|
||||
run: |
|
||||
git fetch --all --prune
|
||||
|
||||
- name: Verify Hotfix Commit Exists
|
||||
run: |
|
||||
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
|
||||
|
||||
- name: Get Release Branches
|
||||
id: get_release_branches
|
||||
run: |
|
||||
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found release branches:"
|
||||
echo "$BRANCHES"
|
||||
|
||||
# Join the branches into a single line separated by commas
|
||||
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
|
||||
|
||||
# Set the branches as an output
|
||||
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
|
||||
|
||||
# notes on all the vagaries of wiring up automated PR's
|
||||
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
|
||||
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
|
||||
- name: Create and Merge Pull Requests to Matching Release Branches
|
||||
env:
|
||||
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
|
||||
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
|
||||
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
|
||||
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
|
||||
run: |
|
||||
# Get the branches from the previous step
|
||||
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
|
||||
|
||||
# Convert BRANCHES to an array
|
||||
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
|
||||
|
||||
# Loop through each release branch and create and merge a PR
|
||||
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
|
||||
echo "Processing $RELEASE_BRANCH..."
|
||||
|
||||
# Parse out the release version by removing "release/" from the branch name
|
||||
RELEASE_VERSION=${RELEASE_BRANCH#release/}
|
||||
echo "Release version parsed: $RELEASE_VERSION"
|
||||
|
||||
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
|
||||
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
|
||||
|
||||
# Checkout the release branch
|
||||
echo "Checking out $RELEASE_BRANCH"
|
||||
git checkout "$RELEASE_BRANCH"
|
||||
|
||||
# Create the new hotfix branch
|
||||
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
|
||||
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
|
||||
else
|
||||
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
|
||||
git checkout -b "$HOTFIX_BRANCH"
|
||||
fi
|
||||
|
||||
# Check if the hotfix commit is a merge commit
|
||||
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
|
||||
# -m 1 uses the target branch as the base (which is what we want)
|
||||
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
|
||||
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
|
||||
else
|
||||
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
|
||||
fi
|
||||
|
||||
# Perform the cherry-pick
|
||||
echo "Executing: $CHERRY_PICK_CMD"
|
||||
eval "$CHERRY_PICK_CMD"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
|
||||
git cherry-pick --abort
|
||||
continue
|
||||
fi
|
||||
|
||||
# Push the hotfix branch to the remote
|
||||
echo "Pushing $HOTFIX_BRANCH..."
|
||||
git push origin "$HOTFIX_BRANCH"
|
||||
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
|
||||
|
||||
# Check if PR already exists
|
||||
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
|
||||
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
|
||||
continue
|
||||
fi
|
||||
|
||||
# Create a new PR and capture the output
|
||||
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
|
||||
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
|
||||
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
|
||||
|
||||
# Extract the URL from the output
|
||||
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
|
||||
echo "Pull request created: $PR_URL"
|
||||
|
||||
# Extract PR number from URL
|
||||
PR_NUMBER=$(basename "$PR_URL")
|
||||
echo "Pull request created: $PR_NUMBER"
|
||||
|
||||
if [ "$AUTO_MERGE" == "true" ]; then
|
||||
echo "Attempting to merge pull request #$PR_NUMBER"
|
||||
|
||||
# Attempt to merge the PR
|
||||
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Pull request #$PR_NUMBER merged successfully."
|
||||
else
|
||||
# Optionally, handle the error or continue
|
||||
echo "Failed to merge pull request #$PR_NUMBER."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
23
.github/workflows/nightly-close-stale-issues.yml
vendored
23
.github/workflows/nightly-close-stale-issues.yml
vendored
@@ -1,23 +0,0 @@
|
||||
name: 'Nightly - Close stale issues and PRs'
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC
|
||||
|
||||
permissions:
|
||||
# contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
|
||||
close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
|
||||
days-before-stale: 75
|
||||
# days-before-close: 90 # uncomment after we test stale behavior
|
||||
|
||||
76
.github/workflows/nightly-scan-licenses.yml
vendored
76
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@v2
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
69
.github/workflows/pr-Integration-tests.yml
vendored
69
.github/workflows/pr-Integration-tests.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
@@ -85,58 +85,7 @@ jobs:
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
|
||||
- name: Start Docker containers
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
@@ -181,7 +130,7 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Standard Integration Tests
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
@@ -196,9 +145,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
danswer/danswer-integration:test
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -211,22 +158,16 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
124
.github/workflows/pr-backport-autotrigger.yml
vendored
124
.github/workflows/pr-backport-autotrigger.yml
vendored
@@ -1,124 +0,0 @@
|
||||
name: Backport on Merge
|
||||
|
||||
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
actions: write
|
||||
|
||||
jobs:
|
||||
backport:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
git fetch --prune
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
# Reset and prepare the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
echo "Last 5 commits on beta branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
|
||||
# Reset and prepare the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
72
.github/workflows/pr-helm-chart-testing.yml
vendored
72
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -1,72 +0,0 @@
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
68
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
68
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
4
.github/workflows/pr-python-checks.yml
vendored
4
.github/workflows/pr-python-checks.yml
vendored
@@ -14,10 +14,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
@@ -18,11 +18,6 @@ env:
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
@@ -37,7 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
|
||||
4
.github/workflows/pr-python-model-tests.yml
vendored
4
.github/workflows/pr-python-model-tests.yml
vendored
@@ -15,7 +15,7 @@ env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
|
||||
2
.github/workflows/pr-python-tests.yml
vendored
2
.github/workflows/pr-python-tests.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -18,6 +18,6 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
|
||||
300
.vscode/launch.template.jsonc
vendored
300
.vscode/launch.template.jsonc
vendored
@@ -6,69 +6,19 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
@@ -79,11 +29,7 @@
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
@@ -102,11 +48,7 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
@@ -126,13 +68,43 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
]
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
{
|
||||
"name": "Indexing",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--no-indexing"
|
||||
]
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -146,151 +118,7 @@
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
@@ -309,22 +137,8 @@
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
]
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
@@ -333,27 +147,7 @@
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
21
README.md
21
README.md
@@ -1,5 +1,4 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
@@ -69,13 +68,13 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Prompt sharing with specific teammates and user groups.
|
||||
* Multimodal model support, chat with images, video etc.
|
||||
* Multi-Model model support, chat with images, video etc.
|
||||
* Choosing between LLMs and parameters during chat session.
|
||||
* Tool calling and agent configurations options.
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
|
||||
## Other Notable Benefits of Danswer
|
||||
## Other Noteable Benefits of Danswer
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#danswer-ai/danswer&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -8,11 +8,10 @@ Edition features outside of personal development or testing purposes. Please rea
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
@@ -37,8 +36,6 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
@@ -77,6 +74,7 @@ RUN apt-get update && \
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
@@ -94,7 +92,6 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
@@ -104,7 +101,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
|
||||
@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For mo
|
||||
visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.8-dev
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
@@ -55,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
|
||||
# Model Server main code
|
||||
COPY ./model_server /app/model_server
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[DEFAULT]
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
@@ -47,8 +47,7 @@ prepend_sys_path = .
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
@@ -107,12 +106,3 @@ formatter = generic
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
||||
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
version_locations = %(script_location)s/versions
|
||||
|
||||
[schema_private]
|
||||
script_location = alembic_tenants
|
||||
version_locations = %(script_location)s/versions
|
||||
|
||||
@@ -1,104 +1,104 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
# Add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
def get_schema_options() -> tuple[str, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
for pair in arg.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
x_args[key] = value
|
||||
|
||||
schema_name = x_args.get("schema", "public")
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
return schema_name, create_schema
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
url = build_connection_string()
|
||||
schema, _ = get_schema_options()
|
||||
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
schema, create_schema = get_schema_options()
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
connection.execute(text(f'SET search_path TO "{schema}"'))
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -106,98 +106,20 @@ def do_run_migrations(
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = create_async_engine(
|
||||
build_connection_string(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema_name,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema_name}: {e}")
|
||||
raise
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
for schema in tenant_schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
else:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add additional data to notifications
|
||||
|
||||
Revision ID: 1b10e1fda030
|
||||
Revises: 6756efa39ada
|
||||
Create Date: 2024-10-15 19:26:44.071259
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b10e1fda030"
|
||||
down_revision = "6756efa39ada"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "additional_data")
|
||||
@@ -1,68 +0,0 @@
|
||||
"""default chosen assistants to none
|
||||
|
||||
Revision ID: 26b931506ecb
|
||||
Revises: 2daa494a0851
|
||||
Create Date: 2024-11-12 13:23:29.858995
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "26b931506ecb"
|
||||
down_revision = "2daa494a0851"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_new =
|
||||
CASE
|
||||
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants_old",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default="[-2, -1, 0]",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_old =
|
||||
CASE
|
||||
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add-group-sync-time
|
||||
|
||||
Revision ID: 2daa494a0851
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-11 10:57:22.991157
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2daa494a0851"
|
||||
down_revision = "c0fd6e4da83a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_external_group_sync",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_time_external_group_sync")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: 33cb72ea4d80
|
||||
Revises: 5b29123cd710
|
||||
Create Date: 2024-11-01 12:51:01.535003
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33cb72ea4d80"
|
||||
down_revision = "5b29123cd710"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Delete extraneous ToolCall entries
|
||||
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool_call
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM tool_call
|
||||
WHERE message_id IS NOT NULL
|
||||
GROUP BY message_id
|
||||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Add a unique constraint on message_id
|
||||
op.create_unique_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
columns=["message_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Drop the unique constraint on message_id
|
||||
op.drop_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
type_="unique",
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""nullable search settings for historic index attempts
|
||||
|
||||
Revision ID: 5b29123cd710
|
||||
Revises: 949b4a92a401
|
||||
Create Date: 2024-10-30 19:37:59.630704
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5b29123cd710"
|
||||
down_revision = "949b4a92a401"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be nullable
|
||||
op.alter_column(
|
||||
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
# Add back the foreign key with ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Warning: This will delete all index attempts that don't have search settings
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE search_settings_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be not nullable
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
"search_settings_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add back the foreign key without ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add api_version and deployment_name to search settings
|
||||
|
||||
Revision ID: 5d12a446f5c0
|
||||
Revises: e4334d5b33ba
|
||||
Create Date: 2024-10-08 15:56:07.975636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5d12a446f5c0"
|
||||
down_revision = "e4334d5b33ba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "deployment_name")
|
||||
op.drop_column("embedding_provider", "api_version")
|
||||
@@ -1,153 +0,0 @@
|
||||
"""Migrate chat_session and chat_message tables to use UUID primary keys
|
||||
|
||||
Revision ID: 6756efa39ada
|
||||
Revises: 5d12a446f5c0
|
||||
Create Date: 2024-10-15 17:47:44.108537
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "6756efa39ada"
|
||||
down_revision = "5d12a446f5c0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
"""
|
||||
This script:
|
||||
1. Adds UUID columns to chat_session and chat_message
|
||||
2. Populates new columns with UUIDs
|
||||
3. Updates foreign key relationships
|
||||
4. Removes old integer ID columns
|
||||
|
||||
Note: Downgrade will assign new integer IDs, not restore original ones.
|
||||
"""
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column(
|
||||
"new_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET new_chat_session_id = cs.new_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "new_id", new_column_name="id")
|
||||
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
|
||||
)
|
||||
|
||||
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
|
||||
op.execute(
|
||||
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
|
||||
)
|
||||
|
||||
op.alter_column("chat_session", "old_id", nullable=False)
|
||||
|
||||
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
|
||||
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
|
||||
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET old_chat_session_id = cs.old_id
|
||||
FROM chat_session cs
|
||||
WHERE chat_message.chat_session_id = cs.id;
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "chat_session_id")
|
||||
op.alter_column(
|
||||
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["old_id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_column("chat_session", "id")
|
||||
op.alter_column("chat_session", "old_id", new_column_name="id")
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
type_=sa.Integer(),
|
||||
existing_type=sa.Integer(),
|
||||
existing_nullable=False,
|
||||
existing_server_default=False,
|
||||
)
|
||||
|
||||
# Rename the sequence
|
||||
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
|
||||
|
||||
# Update the default value to use the renamed sequence
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"id",
|
||||
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
|
||||
)
|
||||
@@ -1,74 +0,0 @@
|
||||
"""remove rt
|
||||
|
||||
Revision ID: 949b4a92a401
|
||||
Revises: 1b10e1fda030
|
||||
Create Date: 2024-10-26 13:06:06.937969
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import your models and constants
|
||||
from danswer.db.models import (
|
||||
Connector,
|
||||
ConnectorCredentialPair,
|
||||
Credential,
|
||||
IndexAttempt,
|
||||
)
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "949b4a92a401"
|
||||
down_revision = "1b10e1fda030"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Deletes all RequestTracker connectors and associated data
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
connectors_to_delete = (
|
||||
session.query(Connector)
|
||||
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
|
||||
.all()
|
||||
)
|
||||
|
||||
connector_ids = [connector.id for connector in connectors_to_delete]
|
||||
|
||||
if connector_ids:
|
||||
cc_pairs_to_delete = (
|
||||
session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.connector_id.in_(connector_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs_to_delete]
|
||||
|
||||
if cc_pair_ids:
|
||||
session.query(IndexAttempt).filter(
|
||||
IndexAttempt.connector_credential_pair_id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
credential_ids = [cc_pair.credential_id for cc_pair in cc_pairs_to_delete]
|
||||
if credential_ids:
|
||||
session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.query(Connector).filter(Connector.id.in_(connector_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No-op downgrade as we cannot restore deleted data
|
||||
pass
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add creator to cc pair
|
||||
|
||||
Revision ID: 9cf5c00f72fe
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-12 15:16:42.682902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9cf5c00f72fe"
|
||||
down_revision = "26b931506ecb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"creator_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "creator_id")
|
||||
@@ -31,12 +31,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# First, update any null values to a default value
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
|
||||
)
|
||||
|
||||
# Then, make the column non-nullable
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
|
||||
@@ -288,15 +288,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""remove description from starter messages
|
||||
|
||||
Revision ID: b72ed7a5db0e
|
||||
Revises: 33cb72ea4d80
|
||||
Create Date: 2024-11-03 15:55:28.944408
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b72ed7a5db0e"
|
||||
down_revision = "33cb72ea4d80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem - 'description')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem || '{"description": ""}')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""add recent assistants
|
||||
|
||||
Revision ID: c0fd6e4da83a
|
||||
Revises: b72ed7a5db0e
|
||||
Create Date: 2024-11-03 17:28:54.916618
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0fd6e4da83a"
|
||||
down_revision = "b72ed7a5db0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
@@ -23,56 +23,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -20,7 +20,7 @@ depends_on: None = None
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text('select id, chosen_assistants from "user"')
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
@@ -37,7 +37,7 @@ def upgrade() -> None:
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
|
||||
)
|
||||
@@ -46,7 +46,7 @@ def upgrade() -> None:
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text('select id, chosen_assistants from "user"')
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
@@ -59,7 +59,7 @@ def downgrade() -> None:
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": chosen_assistants, "id": id},
|
||||
)
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""add_deployment_name_to_llmprovider
|
||||
|
||||
Revision ID: e4334d5b33ba
|
||||
Revises: ac5eaac849f9
|
||||
Create Date: 2024-10-04 09:52:34.896867
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e4334d5b33ba"
|
||||
down_revision = "ac5eaac849f9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "deployment_name")
|
||||
@@ -1,3 +0,0 @@
|
||||
These files are for public table migrations when operating with multi tenancy.
|
||||
|
||||
If you are not a Danswer developer, you can ignore this directory entirely.
|
||||
@@ -1,111 +0,0 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import PublicBase
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = [PublicBase.metadata]
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
|
||||
connectable = create_async_engine(
|
||||
build_connection_string(),
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -1,24 +0,0 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -1,24 +0,0 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "14a83a331951"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"user_tenant_mapping",
|
||||
sa.Column("email", sa.String(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
|
||||
sa.UniqueConstraint("email", name="uq_email"),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("user_tenant_mapping", schema="public")
|
||||
@@ -1,3 +1,3 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
|
||||
|
||||
@@ -16,41 +16,6 @@ class ExternalAccess:
|
||||
is_public: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"external_access": {
|
||||
"external_user_emails": list(self.external_access.external_user_emails),
|
||||
"external_user_group_ids": list(
|
||||
self.external_access.external_user_group_ids
|
||||
),
|
||||
"is_public": self.external_access.is_public,
|
||||
},
|
||||
"doc_id": self.doc_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocExternalAccess":
|
||||
external_access = ExternalAccess(
|
||||
external_user_emails=set(
|
||||
data["external_access"].get("external_user_emails", [])
|
||||
),
|
||||
external_user_group_ids=set(
|
||||
data["external_access"].get("external_user_group_ids", [])
|
||||
),
|
||||
is_public=data["external_access"]["is_public"],
|
||||
)
|
||||
return cls(
|
||||
external_access=external_access,
|
||||
doc_id=data["doc_id"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
@@ -105,12 +70,3 @@ class DocumentAccess(ExternalAccess):
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
default_public_access = DocumentAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import secrets
|
||||
import uuid
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import Request
|
||||
from passlib.hash import sha256_crypt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
# NOTE for others who are curious: In the context of a header, "X-" often refers
|
||||
# to non-standard, experimental, or custom headers in HTTP or other protocols. It
|
||||
# indicates that the header is not part of the official standards defined by
|
||||
# organizations like the Internet Engineering Task Force (IETF).
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
|
||||
_BEARER_PREFIX = "Bearer "
|
||||
_API_KEY_PREFIX = "dn_"
|
||||
_API_KEY_LEN = 192
|
||||
|
||||
|
||||
class ApiKeyDescriptor(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
api_key: str | None = None # only present on initial creation
|
||||
api_key_name: str | None = None
|
||||
api_key_role: UserRole
|
||||
|
||||
user_id: uuid.UUID
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
|
||||
|
||||
|
||||
def extract_tenant_from_api_key_header(request: Request) -> str | None:
|
||||
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
|
||||
raw_api_key_header = request.headers.get(
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME
|
||||
) or request.headers.get(_API_KEY_HEADER_NAME)
|
||||
|
||||
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
|
||||
return None
|
||||
|
||||
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
if not api_key.startswith(_API_KEY_PREFIX):
|
||||
return None
|
||||
|
||||
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
tenant_id = parts[0]
|
||||
return unquote(tenant_id) if tenant_id else None
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
# NOTE: no salt is needed, as the API key is randomly generated
|
||||
# and overlaps are impossible
|
||||
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
|
||||
|
||||
|
||||
def build_displayable_api_key(api_key: str) -> str:
|
||||
if api_key.startswith(_API_KEY_PREFIX):
|
||||
api_key = api_key[len(_API_KEY_PREFIX) :]
|
||||
|
||||
return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]
|
||||
|
||||
|
||||
def get_hashed_api_key_from_request(request: Request) -> str | None:
|
||||
raw_api_key_header = request.headers.get(
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME
|
||||
) or request.headers.get(_API_KEY_HEADER_NAME)
|
||||
if raw_api_key_header is None:
|
||||
return None
|
||||
|
||||
if raw_api_key_header.startswith(_BEARER_PREFIX):
|
||||
raw_api_key_header = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
return hash_api_key(raw_api_key_header)
|
||||
@@ -9,7 +9,6 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
@@ -15,7 +15,6 @@ class UserRole(str, Enum):
|
||||
for all groups they are a member of
|
||||
"""
|
||||
|
||||
LIMITED = "limited"
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
@@ -35,7 +34,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
|
||||
@@ -5,23 +5,18 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import EmailUndeliverableError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
@@ -31,36 +26,22 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.manager import UserManagerDependency
|
||||
from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
from fastapi_users.router.common import ErrorModel
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DATA_PLANE_SECRET
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import EXPECTED_API_KEY
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
@@ -75,36 +56,23 @@ from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -139,9 +107,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
@@ -152,10 +118,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email)
|
||||
except EmailUndeliverableError:
|
||||
raise PermissionError("Email is not valid")
|
||||
email_info = validate_email(email) # can raise EmailNotValidError
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
@@ -173,8 +136,8 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
@@ -228,63 +191,35 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -299,122 +234,45 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
verify_email_in_whitelist(account_email)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
user = await super().oauth_callback( # type: ignore
|
||||
oauth_name=oauth_name,
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
account_email=account_email,
|
||||
expires_at=expires_at,
|
||||
refresh_token=refresh_token,
|
||||
request=request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db # type: ignore
|
||||
|
||||
oauth_account_dict = {
|
||||
"oauth_name": oauth_name,
|
||||
"access_token": access_token,
|
||||
"account_id": account_id,
|
||||
"account_email": account_email,
|
||||
"expires_at": expires_at,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
user = await self.get_by_oauth_account(oauth_name, account_id)
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user, existing_oauth_account, oauth_account_dict
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
user, update_dict={"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login: # type: ignore
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True # type: ignore
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
@@ -445,55 +303,28 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
except exceptions.UserNotExists:
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
# Create a tenant-specific session
|
||||
async with get_async_session_with_tenant(tenant_id) as tenant_session:
|
||||
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
|
||||
tenant_session, User
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
self.user_db = tenant_user_db
|
||||
|
||||
# Proceed with authentication
|
||||
try:
|
||||
user = await self.get_by_email(email)
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(user, {"hashed_password": updated_password_hash})
|
||||
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(
|
||||
user, {"hashed_password": updated_password_hash}
|
||||
)
|
||||
|
||||
return user
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
@@ -508,51 +339,20 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
strategy = DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
name="database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -566,11 +366,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
This way the login router does not need to be included
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
logout_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
@@ -617,8 +415,8 @@ async def optional_user_(
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
@@ -635,12 +433,14 @@ async def double_check_user(
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
@@ -649,7 +449,8 @@ async def double_check_user(
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
@@ -662,24 +463,10 @@ async def current_user_with_expired_token(
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_limited_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
user = await double_check_user(user)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if user.role == UserRole.LIMITED:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
return user
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
@@ -689,13 +476,15 @@ async def current_curator_or_admin_user(
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
@@ -707,7 +496,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
@@ -719,210 +509,26 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
authorization_url: str
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
def create_danswer_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
state_secret: SecretType,
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> APIRouter:
|
||||
return get_oauth_router(
|
||||
oauth_client,
|
||||
backend,
|
||||
get_user_manager,
|
||||
state_secret,
|
||||
redirect_url,
|
||||
associate_by_email,
|
||||
is_verified_by_default,
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
state_secret: SecretType,
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"
|
||||
|
||||
if redirect_url is not None:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
else:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client,
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/authorize",
|
||||
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
scopes: List[str] = Query(None),
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
"/callback",
|
||||
name=callback_route_name,
|
||||
description="The response varies based on the authentication backend used.",
|
||||
responses={
|
||||
status.HTTP_400_BAD_REQUEST: {
|
||||
"model": ErrorModel,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"INVALID_STATE_TOKEN": {
|
||||
"summary": "Invalid state token.",
|
||||
"value": None,
|
||||
},
|
||||
ErrorCode.LOGIN_BAD_CREDENTIALS: {
|
||||
"summary": "User is inactive.",
|
||||
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> RedirectResponse:
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
if hasattr(response, "status_code"):
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
|
||||
if user is None:
|
||||
async def control_plane_dep(request: Request) -> None:
|
||||
api_key = request.headers.get("X-API-KEY")
|
||||
if api_key != EXPECTED_API_KEY:
|
||||
logger.warning("Invalid API key")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
logger.warning("Invalid authorization header")
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
try:
|
||||
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
|
||||
if payload.get("scope") != "tenant:create":
|
||||
logger.warning("Insufficient permissions")
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
logger.warning("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
@@ -1,402 +0,0 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
|
||||
This also does not fire if a worker with acks_late=False crashes (which all of our
|
||||
long running workers are)
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(tenant_id, int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(tenant_id, int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDelete.PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPermissionSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorExternalGroupSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Redis: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Redis: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for the db to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Database: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result:
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Database: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Database: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
@@ -1,172 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._update_tenant_tasks()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
now = self.app.now()
|
||||
if (
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating tenant task update")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Tenant task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting tenant task update process")
|
||||
try:
|
||||
logger.info("Fetching all tenant IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} tenants")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
logger.debug(f"Creating task configuration for {task_name}")
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_beat_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
logger.debug("Comparing current and new schedules")
|
||||
current_tasks = set(name for name, _ in current_schedule)
|
||||
new_tasks = set(new_schedule.keys())
|
||||
needs_update = current_tasks != new_tasks
|
||||
logger.debug(f"Schedule update needed: {needs_update}")
|
||||
return needs_update
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
@@ -1,97 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.heavy")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
@@ -1,101 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.indexing")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
]
|
||||
)
|
||||
@@ -1,97 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.light")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
]
|
||||
)
|
||||
@@ -1,248 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.primary")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
RedisDocumentSet.reset_all(r)
|
||||
|
||||
RedisUserGroup.reset_all(r)
|
||||
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
|
||||
RedisConnectorPrune.reset_all(r)
|
||||
|
||||
RedisConnectorIndex.reset_all(r)
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
|
||||
RedisConnectorPermissionSync.reset_all(r)
|
||||
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# it's unclear to me whether using the hub's timer or the bootstep timer is better
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
def start(self, worker: Any) -> None:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
# Access the worker's event loop (hub)
|
||||
hub = worker.consumer.controller.hub
|
||||
|
||||
# Schedule the periodic task
|
||||
self.task_tref = hub.call_repeatedly(
|
||||
self.interval, self.run_periodic_task, worker
|
||||
)
|
||||
task_logger.info("Scheduled periodic task with hub.")
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
worker.primary_worker_lock = lock
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Periodic task failed.")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
if self.task_tref:
|
||||
self.task_tref.cancel()
|
||||
task_logger.info("Canceled periodic task with hub.")
|
||||
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
import logging
|
||||
|
||||
from celery import current_task
|
||||
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class CeleryTaskColoredFormatter(ColoredFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
484
backend/danswer/background/celery/celery_app.py
Normal file
484
backend/danswer/background/celery/celery_app.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import current_task
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import beat_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object(
|
||||
"danswer.background.celery.celeryconfig"
|
||||
) # Load configuration from 'celeryconfig.py'
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def celery_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
# logger.debug(f"Result: {retval}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# decide some initial startup settings based on the celery worker's hostname
|
||||
# (set at the command line)
|
||||
hostname = sender.hostname
|
||||
if hostname.startswith("light"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
elif hostname.startswith("heavy"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
else:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time.monotonic()
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
r = get_redis_client()
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
sender.primary_worker_lock = lock
|
||||
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class CeleryTaskColoredFormatter(ColoredFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats celery's worker logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# it's unclear to me whether using the hub's timer or the bootstep timer is better
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
def start(self, worker: Any) -> None:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
# Access the worker's event loop (hub)
|
||||
hub = worker.consumer.controller.hub
|
||||
|
||||
# Schedule the periodic task
|
||||
self.task_tref = hub.call_repeatedly(
|
||||
self.interval, self.run_periodic_task, worker
|
||||
)
|
||||
task_logger.info("Scheduled periodic task with hub.")
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not worker.primary_worker_lock:
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock: redis.lock.Lock = worker.primary_worker_lock
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be computer sleep or a clock change."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
|
||||
worker.primary_worker_lock = lock
|
||||
except Exception:
|
||||
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
if self.task_tref:
|
||||
self.task_tref.cancel()
|
||||
task_logger.info("Canceled periodic task with hub.")
|
||||
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-vespa-sync": {
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-connector-deletion-task": {
|
||||
"task": "check_for_connector_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-prune": {
|
||||
"task": "check_for_prune_task_2",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"kombu-message-cleanup": {
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"monitor-vespa-sync": {
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1,10 +1,465 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id, current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class differs from the default in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
"""id: the cc_pair_id of the connector credential pair"""
|
||||
|
||||
super().__init__(id)
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=self._id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair_id {self._id} does not exist.")
|
||||
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
"""Entry point for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
celery_app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.celery_app", "celery_app"
|
||||
)
|
||||
@@ -1,23 +1,24 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -26,10 +27,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -40,26 +38,21 @@ def _get_deletion_status(
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client()
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
)
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
@@ -70,31 +63,26 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: list[Document],
|
||||
) -> set[str]:
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasnt been implemented for the given connector, just pull
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
@@ -103,17 +91,16 @@ def extract_ids_from_runnable_connector(
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Stop signal received")
|
||||
callback.progress(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
@@ -134,10 +121,13 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
def celery_is_worker_primary(worker: Any) -> bool:
|
||||
"""There are multiple approaches that could be taken to determine if a celery worker
|
||||
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
||||
for the celery worker, which can be done on the
|
||||
for the celery worker, which can be done either in celeryconfig.py or on the
|
||||
command line with '--hostname'."""
|
||||
hostname = worker.hostname
|
||||
if hostname.startswith("primary"):
|
||||
return True
|
||||
if hostname.startswith("light"):
|
||||
return False
|
||||
|
||||
return False
|
||||
if hostname.startswith("heavy"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
import urllib.parse
|
||||
|
||||
from danswer.configs.app_configs import CELERY_BROKER_POOL_LIMIT
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
@@ -19,7 +17,7 @@ CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@"
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
REDIS_SCHEME = "redis"
|
||||
|
||||
@@ -31,10 +29,16 @@ if REDIS_SSL:
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||
|
||||
# region Broker settings
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_connection_retry_on_startup = True
|
||||
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
||||
|
||||
@@ -49,7 +53,6 @@ broker_transport_options = {
|
||||
"socket_keepalive": True,
|
||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
}
|
||||
# endregion
|
||||
|
||||
# redis backend settings
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
@@ -63,19 +66,10 @@ redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# region Task result backend settings
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
# endregion
|
||||
|
||||
# Leaving this to the default of True may cause double logging since both our own app
|
||||
# and celery think they are controlling the logger.
|
||||
# TODO: Configure celery's logger entirely manually and set this to False
|
||||
# worker_hijack_root_logger = False
|
||||
|
||||
# region Notes on serialization performance
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
@@ -101,4 +95,3 @@ result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
# endregion
|
||||
@@ -1,14 +0,0 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
@@ -1,20 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,21 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,22 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY
|
||||
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
||||
@@ -1,20 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,60 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
@@ -1,41 +1,30 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
"""Raised to the caller to indicate dependent tasks are running that would interfere
|
||||
with connector deletion."""
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
def check_for_connector_deletion_task() -> None:
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -47,141 +36,75 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
redis_connector.stop.set_fence(True)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
|
||||
Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
|
||||
still running. In our case, the caller reacts by setting a stop signal in Redis to
|
||||
exit those tasks as quickly as possible.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if redis_connector.delete.fenced:
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to load the state of the object inside the fence
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
fence_payload = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
redis_connector.delete.taskset_clear()
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.delete.generate_tasks(
|
||||
app, db_session, lock_beat
|
||||
)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
redis_connector.delete.set_fence(None)
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.document import upsert_document_external_perms
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip doc permissions sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_perm_sync = cc_pair.last_time_perm_sync
|
||||
if last_perm_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
if redis_connector.permissions.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
app.send_task(
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=None,
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
bind=True,
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_non_web_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
return False
|
||||
@@ -1,265 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external group sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_ext_group_sync = cc_pair.last_time_external_group_sync
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Dont kick off a new sync if the previous one is still running
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
_ = app.send_task(
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
# set a basic fence to start
|
||||
redis_connector.external_group_sync.set_fence(True)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(f"No external group sync func found for {source_type}")
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=external_user_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(False)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -1,651 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
def __init__(
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
return False
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
if old_search_settings:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
# only warm up if search settings were changed
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Indexing queued: index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
)
|
||||
tasks_created += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return tasks_created
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Checks various global settings and past indexing attempts to determine if
|
||||
we should try to start indexing the cc pair / search setting combination.
|
||||
|
||||
Note that tactical checks such as preventing overlap with a currently running task
|
||||
are not handled here.
|
||||
|
||||
Return True if we should try to index, False if not.
|
||||
"""
|
||||
connector = cc_pair.connector
|
||||
|
||||
# uncomment for debugging
|
||||
# task_logger.info(f"_should_index: "
|
||||
# f"cc_pair={cc_pair.id} "
|
||||
# f"connector={cc_pair.connector_id} "
|
||||
# f"refresh_freq={connector.refresh_freq}")
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
# and cleans them up
|
||||
# therefore we must create the attempt and the task after the fence goes up
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
from_beginning=reindex,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
redis_connector_index.set_fence(payload)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing proxy - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
f"Indexing proxy - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
while True:
|
||||
# wait for the fence to come up
|
||||
if not redis_connector_index.fenced:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise ValueError(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise ValueError(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return n_final_progress
|
||||
@@ -11,10 +11,10 @@ from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -23,7 +23,7 @@ from danswer.db.engine import get_session_with_tenant
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
def kombu_message_cleanup_task(self: Any) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
|
||||
@@ -3,77 +3,39 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import pruning_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due."""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
name="check_for_prune_task_2",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
def check_for_prune_task_2() -> None:
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -85,164 +47,121 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
for cc_pair in cc_pairs:
|
||||
tasks_created = ccpair_pruning_generator_task_creation_helper(
|
||||
cc_pair, db_session, r, lock_beat
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
|
||||
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
def ccpair_pruning_generator_task_creation_helper(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
"""Returns an int if pruning is triggered.
|
||||
The int represents the number of prune tasks generated (in this case, only one
|
||||
because the task is a long running generator task.)
|
||||
Returns None if no pruning is triggered (due to not being needed or
|
||||
other reasons such as simultaneous pruning restrictions.
|
||||
|
||||
Checks for scheduling related conditions, then delegates the rest of the checks to
|
||||
try_creating_prune_generator_task.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return None
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
# if never pruned, use the connector time created as the last_pruned time
|
||||
last_pruned = cc_pair.connector.time_created
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return None
|
||||
|
||||
return try_creating_prune_generator_task(cc_pair, db_session, r)
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger prunes immediately, e.g. via the web ui.
|
||||
is used to trigger prunes immediately.
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
count = redis_connector.prune.get_active_task_count()
|
||||
if count > 0:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# we need to serialize starting pruning since it can be triggered either via
|
||||
# celery beat or manually (API call)
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
# skip pruning if already pruning
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# skip pruning if already pruning
|
||||
if redis_connector.prune.fenced:
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
# skip pruning if doc permissions sync is running
|
||||
if redis_connector.permissions.fenced:
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcp.fence_key, 1)
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_pruning_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
|
||||
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
pruning_ctx.set(pruning_ctx_dict)
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
@@ -251,27 +170,27 @@ def connector_pruning_generator_task(
|
||||
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"cc_pair not found for {connector_id} {credential_id}"
|
||||
f"ccpair not found for {connector_id} {credential_id}"
|
||||
)
|
||||
return
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
r.incrby(rcp.generator_progress_key, amount)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
runnable_connector, redis_increment_callback
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
@@ -289,35 +208,32 @@ def connector_pruning_generator_task(
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"cc_pair_id={cc_pair.id} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id}."
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import httpx
|
||||
from tenacity import retry
|
||||
from tenacity import retry_if_exception_type
|
||||
from tenacity import stop_after_delay
|
||||
from tenacity import wait_random_exponential
|
||||
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
|
||||
|
||||
class RetryDocumentIndex:
|
||||
"""A wrapper class to help with specific retries against Vespa involving
|
||||
read timeouts.
|
||||
|
||||
wait_random_exponential implements full jitter as per this article:
|
||||
https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/"""
|
||||
|
||||
MAX_WAIT = 30
|
||||
|
||||
# STOP_AFTER + MAX_WAIT should be slightly less (5?) than the celery soft_time_limit
|
||||
STOP_AFTER = 70
|
||||
|
||||
def __init__(self, index: DocumentIndex):
|
||||
self.index: DocumentIndex = index
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
return self.index.delete_single(doc_id)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
|
||||
return self.index.update_single(doc_id, fields)
|
||||
@@ -1,48 +1,32 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from tenacity import RetryError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import mark_document_as_modified
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
self: Task, document_id: str, connector_id: int, credential_id: int
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
@@ -59,34 +43,25 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = retry_index.delete_single(document_id)
|
||||
document_index.delete(doc_ids=[document_id])
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
elif count > 1:
|
||||
action = "update"
|
||||
|
||||
# count > 1 means the document still has cc_pair references
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
@@ -109,7 +84,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields=fields)
|
||||
document_index.update_single(document_id, fields=fields)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
@@ -126,59 +101,13 @@ def document_by_cc_pair_cleanup_task(
|
||||
pass
|
||||
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
|
||||
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
|
||||
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
else:
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id):
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
return False
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,33 +1,25 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from danswer.db.connector import mark_ccpair_as_pruned
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
@@ -37,7 +29,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
@@ -45,39 +36,20 @@ from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@@ -85,13 +57,12 @@ logger = setup_logger()
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task() -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -103,72 +74,47 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
|
||||
|
||||
# region document set scan
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
|
||||
for document_set, _ in document_set_info:
|
||||
document_set_ids.append(document_set.id)
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
document_set, db_session, r, lock_beat
|
||||
)
|
||||
# endregion
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
# check if any user groups are not synced
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
# We shouldn't actually get here if the ee version check works
|
||||
pass
|
||||
else:
|
||||
usergroup_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
|
||||
for usergroup in user_groups:
|
||||
usergroup_ids.append(usergroup.id)
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
@@ -192,10 +138,8 @@ def try_generate_stale_document_sync_tasks(
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
@@ -219,27 +163,19 @@ def try_generate_stale_document_sync_tasks(
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
celery_app: Celery,
|
||||
document_set_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if rds.fenced:
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
document_set = get_document_set_by_id(db_session, document_set_id)
|
||||
if not document_set:
|
||||
return None
|
||||
|
||||
db_session.refresh(document_set)
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
@@ -251,9 +187,7 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
@@ -269,34 +203,23 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
celery_app: Celery,
|
||||
usergroup_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(tenant_id, usergroup_id)
|
||||
if rug.fenced:
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_group"
|
||||
)
|
||||
|
||||
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||
if not usergroup:
|
||||
return None
|
||||
|
||||
db_session.refresh(usergroup)
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
@@ -307,9 +230,7 @@ def try_generate_user_group_sync_tasks(
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
@@ -325,7 +246,7 @@ def try_generate_user_group_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -351,28 +272,29 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id_str is None:
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
task_logger.warning(f"could not parse document set id from {fence_key}")
|
||||
return
|
||||
|
||||
document_set_id = int(document_set_id_str)
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
if not rds.fenced:
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial_count = rds.payload
|
||||
if initial_count is None:
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set={document_set_id} "
|
||||
f"remaining={count} initial={initial_count}"
|
||||
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
@@ -386,73 +308,58 @@ def monitor_document_set_taskset(
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set: document_set={document_set_id}"
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
|
||||
# gating off pruning and indexing work before deletion starts
|
||||
task_logger.warning(
|
||||
f"Connector deletion - documents still found after taskset completion: "
|
||||
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
@@ -469,7 +376,7 @@ def monitor_connector_deletion_taskset(
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -486,210 +393,89 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
add_deletion_failure_message(db_session, cc_pair.id, error_message)
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
f"Failed to run connector_deletion. "
|
||||
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
f"Successfully deleted cc_pair: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"docs_deleted={initial_count}"
|
||||
)
|
||||
|
||||
redis_connector.delete.taskset_clear()
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
mark_ccpair_as_pruned(cc_pair_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncData | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None:
|
||||
if result_state in READY_STATES:
|
||||
# IF the task state is READY, THEN generator_complete should be set
|
||||
# if it isn't, then the worker crashed
|
||||
task_logger.info(
|
||||
f"Connector indexing aborted: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason="Connector indexing aborted or exceptioned.",
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
|
||||
def monitor_vespa_sync() -> None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -697,107 +483,30 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return False
|
||||
return
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
if not r.exists(fence_key):
|
||||
failure_reason = (
|
||||
f"Unknown index attempt. Might be left over from a process restart: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
task_logger.warning(failure_reason)
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
@@ -811,28 +520,24 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
@@ -845,54 +550,24 @@ def vespa_metadata_sync_task(
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields)
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.heavy import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.indexing import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.light import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -11,8 +11,7 @@ from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,26 +28,16 @@ JobStatusType = (
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on SQLAlchemy's recommendations to handle multiprocessing:
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
# Optionally set a custom app name for database logging purposes
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
get_sqlalchemy_engine().dispose(close=False)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -19,12 +17,13 @@ from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import transition_attempt_to_in_progress
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
@@ -33,8 +32,8 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.logger import TaskAttemptSingleton
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -42,25 +41,11 @@ logger = setup_logger()
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
class RunIndexingCallbackInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, amount: int) -> None:
|
||||
"""Send progress updates to the caller."""
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
@@ -78,7 +63,6 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
@@ -105,26 +89,15 @@ def _get_connector_runner(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
|
||||
TODO: do not change index attempt statuses here ... instead, set signals in redis
|
||||
and allow the monitor function to clean them up
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
@@ -156,7 +129,6 @@ def _run_indexing(
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
@@ -213,7 +185,6 @@ def _run_indexing(
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
@@ -226,12 +197,7 @@ def _run_indexing(
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
@@ -246,9 +212,7 @@ def _run_indexing(
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt.status}"
|
||||
)
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
@@ -268,8 +232,6 @@ def _run_indexing(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
@@ -288,9 +250,6 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress(len(doc_batch))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
@@ -337,7 +296,7 @@ def _run_indexing(
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
@@ -372,7 +331,7 @@ def _run_indexing(
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
@@ -414,45 +373,69 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
|
||||
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
|
||||
# make sure that the index attempt can't change in between checking the
|
||||
# status and marking it as in_progress. This setting will be discarded
|
||||
# after the next commit:
|
||||
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
|
||||
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
|
||||
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
if attempt.status != IndexingStatus.NOT_STARTED:
|
||||
raise RuntimeError(
|
||||
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
|
||||
f"Current status is '{attempt.status}'."
|
||||
)
|
||||
|
||||
# only commit once, to make sure this all happens in a single transaction
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
# as in progress
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing starting: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
_run_indexing(db_session, attempt)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing finished: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
)
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
|
||||
495
backend/danswer/background/update.py
Executable file
495
backend/danswer/background/update.py
Executable file
@@ -0,0 +1,495 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import dask
|
||||
from dask.distributed import Client
|
||||
from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If the indexing dies, it's most likely due to resource constraints,
|
||||
# restarting just delays the eventual failure, not useful to the user
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
|
||||
_UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
"Stopped mid run, likely due to the background process being killed"
|
||||
)
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if (
|
||||
not cc_pair.status.is_active()
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
|
||||
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
)
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
# clean up completed jobs
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
or job.status == "error"
|
||||
):
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# If index attempt is canceled, stop the run
|
||||
if index_attempt.status == IndexingStatus.FAILED:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def kickoff_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
secondary_client: Client | SimpleJobClient,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.search_settings)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
primary_client_full = False
|
||||
secondary_client_full = False
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
if primary_client_full and secondary_client_full:
|
||||
break
|
||||
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
continue
|
||||
|
||||
if not use_secondary_index:
|
||||
if not primary_client_full:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
primary_client_full = True
|
||||
else:
|
||||
if not secondary_client_full:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
secondary_client_full = True
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster_primary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
# there are warning about high memory usage + "Event loop unresponsive"
|
||||
# which are not relevant to us since our workers are expected to use a
|
||||
# lot of memory + involve CPU intensive tasks that will not relinquish
|
||||
# the event loop
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_secondary_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
client_primary = Client(cluster_primary)
|
||||
client_secondary = Client(cluster_secondary)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.debug(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
check_index_swap(db_session)
|
||||
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
|
||||
create_indexing_jobs(existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=existing_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
# initialize the Postgres connection pool
|
||||
SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.notice("Starting indexing service")
|
||||
update_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update__main()
|
||||
@@ -1,8 +1,6 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
@@ -35,7 +33,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
# Optional id at which we finish processing
|
||||
@@ -168,31 +166,3 @@ def reorganize_citations(
|
||||
new_citation_info[citation.citation_num] = citation
|
||||
|
||||
return new_answer, list(new_citation_info.values())
|
||||
|
||||
|
||||
def extract_headers(
|
||||
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Extract headers specified in pass_through_headers from input headers.
|
||||
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
|
||||
|
||||
Args:
|
||||
headers: Input headers as dict or Headers object.
|
||||
|
||||
Returns:
|
||||
dict: Filtered headers based on pass_through_headers.
|
||||
"""
|
||||
if not pass_through_headers:
|
||||
return {}
|
||||
|
||||
extracted_headers: dict[str, str] = {}
|
||||
for key in pass_through_headers:
|
||||
if key in headers:
|
||||
extracted_headers[key] = headers[key]
|
||||
else:
|
||||
# fastapi makes all header keys lowercase, handling that here
|
||||
lowercase_key = key.lower()
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
|
||||
@@ -41,15 +41,6 @@ personas:
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Give me an overview of what's here"
|
||||
message: "Sample some documents and tell me what you find."
|
||||
- name: "Use AI to solve a work related problem"
|
||||
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
|
||||
- name: "Find updates on a topic of interest"
|
||||
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
|
||||
- name: "Surface contradictions"
|
||||
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
@@ -66,15 +57,6 @@ personas:
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Summarize a document"
|
||||
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
|
||||
- name: "Help me with coding"
|
||||
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
|
||||
- name: "Draft a professional email"
|
||||
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
|
||||
- name: "Learn something new"
|
||||
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -91,15 +73,7 @@ personas:
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
starter_messages:
|
||||
- name: "Document Search"
|
||||
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
|
||||
- name: "Process Verification"
|
||||
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
|
||||
- name: "Technical Documentation"
|
||||
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
|
||||
- name: "Policy Reference"
|
||||
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
@@ -112,17 +86,8 @@ personas:
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Create visuals for a presentation"
|
||||
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
|
||||
- name: "Find inspiration for a marketing campaign"
|
||||
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
|
||||
- name: "Visualize a product design"
|
||||
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
|
||||
- name: "Generate a humorous image response"
|
||||
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"
|
||||
|
||||
@@ -11,18 +11,19 @@ from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -35,6 +36,7 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -54,13 +56,14 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@@ -69,48 +72,35 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -261,11 +251,10 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -281,8 +270,8 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
) -> ChatPacketStream:
|
||||
@@ -292,9 +281,6 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -416,20 +402,12 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -500,19 +478,13 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -527,13 +499,7 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -545,7 +511,6 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@@ -561,53 +526,122 @@ def stream_chat_message_objects(
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = llm.config
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=db_tool_model.custom_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=answer_style_config,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
@@ -675,6 +709,7 @@ def stream_chat_message_objects(
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
@@ -692,7 +727,7 @@ def stream_chat_message_objects(
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
@@ -706,32 +741,11 @@ def stream_chat_message_objects(
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
@@ -761,7 +775,6 @@ def stream_chat_message_objects(
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
logger.debug("Post-LLM answer processing")
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
@@ -777,6 +790,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -784,21 +798,21 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else None
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
@@ -822,8 +836,8 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -831,8 +845,8 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
|
||||
@@ -9,19 +9,19 @@ prompts:
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
|
||||
If there are no relevant documents, refer to the chat history and your internal knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
|
||||
@@ -30,21 +30,21 @@ prompts:
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "ImageGeneration"
|
||||
description: "Generates images from user descriptions!"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
|
||||
|
||||
For appropriate requests, you will generate an image that matches the user's requirements.
|
||||
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
|
||||
|
||||
You aim to be helpful while maintaining appropriate content standards.
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Based on the user's description, create a high-quality image that accurately reflects their request.
|
||||
Pay close attention to the specified details, styles, and desired elements.
|
||||
|
||||
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
@@ -64,13 +64,14 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
@@ -83,6 +84,7 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
@@ -90,10 +92,10 @@ prompts:
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
|
||||
@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
@@ -56,6 +53,7 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
@@ -118,22 +116,17 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
|
||||
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
|
||||
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
|
||||
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
|
||||
|
||||
# The default below is for dockerized deployment
|
||||
VESPA_DEPLOYMENT_ZIP = (
|
||||
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
|
||||
)
|
||||
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
|
||||
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
|
||||
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
@@ -163,17 +156,6 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
# Experimental setting to control idle transactions
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
|
||||
try:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
|
||||
os.environ.get(
|
||||
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
@@ -211,41 +193,6 @@ try:
|
||||
except ValueError:
|
||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
||||
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
|
||||
try:
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
|
||||
try:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_INDEXING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -262,6 +209,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
# for some connectors
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
GOOGLE_DRIVE_INCLUDE_SHARED = False
|
||||
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
|
||||
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
|
||||
|
||||
# TODO these should be available for frontend configuration, via advanced options expandable
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
@@ -298,6 +248,12 @@ CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save pages labels as Danswer metadata tags
|
||||
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
|
||||
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
@@ -445,11 +401,6 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
@@ -462,48 +413,10 @@ ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Azure DALL-E Configurations
|
||||
AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION")
|
||||
AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
|
||||
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
|
||||
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
|
||||
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
|
||||
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# Security and authentication
|
||||
DATA_PLANE_SECRET = os.environ.get(
|
||||
"DATA_PLANE_SECRET", ""
|
||||
) # Used for secure communication between the control and data plane
|
||||
EXPECTED_API_KEY = os.environ.get(
|
||||
"EXPECTED_API_KEY", ""
|
||||
) # Additional security check for the control plane API
|
||||
|
||||
# API configuration
|
||||
CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
|
||||
)
|
||||
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
|
||||
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
|
||||
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
|
||||
|
||||
@@ -39,10 +39,9 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
@@ -67,25 +66,10 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
@@ -129,17 +113,10 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -164,9 +141,6 @@ class AuthType(str, Enum):
|
||||
OIDC = "oidc"
|
||||
SAML = "saml"
|
||||
|
||||
# google auth and basic
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
@@ -213,43 +187,17 @@ class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
# Light queue
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_doc_permissions_sync_beat"
|
||||
)
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
|
||||
PRUNING_LOCK_PREFIX = "da_lock:pruning"
|
||||
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
|
||||
@@ -119,14 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
||||
|
||||
# if specified, will merge the specified JSON with the existing body of the
|
||||
# request before sending it to the LLM
|
||||
LITELLM_EXTRA_BODY: dict | None = None
|
||||
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
|
||||
if _LITELLM_EXTRA_BODY_RAW:
|
||||
try:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# if specified, will pass through request headers to the call to API calls made by custom tools
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
|
||||
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
|
||||
)
|
||||
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
|
||||
try:
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
|
||||
)
|
||||
except Exception:
|
||||
# need to import here to avoid circular imports
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
logger.error(
|
||||
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
@@ -13,8 +13,8 @@ Connectors come in 3 different flows:
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
changes additions and changes since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which generally be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -23,7 +24,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -44,6 +44,8 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
params = {
|
||||
"count": str(batch_size),
|
||||
"offset": str(start_ind),
|
||||
@@ -61,7 +63,8 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
batch = bookstack_client.get(endpoint, params=params).get("data", [])
|
||||
doc_batch = [transformer(bookstack_client, item) for item in batch]
|
||||
for item in batch:
|
||||
doc_batch.append(transformer(bookstack_client, item))
|
||||
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -18,7 +19,6 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"
|
||||
@@ -210,7 +210,6 @@ if __name__ == "__main__":
|
||||
"clickup_team_id": os.environ["clickup_team_id"],
|
||||
}
|
||||
)
|
||||
|
||||
latest_docs = clickup_connector.load_from_state()
|
||||
|
||||
for doc in latest_docs:
|
||||
|
||||
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user