mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
1 Commits
borders
...
remove_emp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81e1ac9183 |
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
|
||||
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -28,11 +28,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -41,16 +41,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -65,18 +65,17 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -96,42 +95,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
76
.github/workflows/nightly-scan-licenses.yml
vendored
76
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@v2
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
@@ -13,10 +13,7 @@ on:
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
@@ -198,13 +195,9 @@ jobs:
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
/app/tests/integration/tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -217,18 +210,17 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
225
.github/workflows/pr-chromatic-tests.yml
vendored
225
.github/workflows/pr-chromatic-tests.yml
vendored
@@ -1,225 +0,0 @@
|
||||
name: Run Chromatic Tests
|
||||
concurrency:
|
||||
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Chromatic automatically defaults to the test-results directory.
|
||||
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
retention-days: 30
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
chromatic-tests:
|
||||
name: Chromatic Tests
|
||||
|
||||
needs: playwright-tests
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Download Playwright test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
playwright: true
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: ./web
|
||||
env:
|
||||
CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
72
.github/workflows/pr-helm-chart-testing.yml
vendored
72
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -1,72 +0,0 @@
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
68
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
68
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -18,14 +18,6 @@ env:
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -15,7 +15,7 @@ env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,4 +7,3 @@
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
/web/test-results/
|
||||
4
.vscode/launch.template.jsonc
vendored
4
.vscode/launch.template.jsonc
vendored
@@ -203,7 +203,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
@@ -232,7 +232,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
"connector_pruning",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
|
||||
@@ -32,7 +32,7 @@ To contribute to this project, please follow the
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
|
||||
### Getting Help 🙋
|
||||
|
||||
71
README.md
71
README.md
@@ -1,48 +1,47 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
|
||||
<a name="readme-top"></a>
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
|
||||
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2sslpdbyq-iIbTaNIVPBw_i_4vrujLYQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
|
||||
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring AI Assistants.
|
||||
configuring Personas (AI Assistants) and their Prompts.
|
||||
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Usage</h3>
|
||||
|
||||
Onyx Web App:
|
||||
Danswer Web App:
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
@@ -52,16 +51,16 @@ For more details on the Admin UI to manage connectors and users, check out our
|
||||
|
||||
## Deployment
|
||||
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
|
||||
|
||||
## 💃 Main Features
|
||||
* Chat UI with the ability to select documents to chat with.
|
||||
* Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Document Search + AI Answers for natural language queries.
|
||||
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
* Slack integration to get answers and search results directly in Slack.
|
||||
@@ -75,12 +74,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
|
||||
## Other Notable Benefits of Onyx
|
||||
## Other Notable Benefits of Danswer
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
* Custom deep learning models + learn from user feedback.
|
||||
* Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
* Easy deployment and ability to host Danswer anywhere of your choosing.
|
||||
|
||||
|
||||
## 🔌 Connectors
|
||||
@@ -108,10 +107,10 @@ Efficiently pulls the latest changes from:
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Onyx:
|
||||
There are two editions of Danswer:
|
||||
|
||||
* Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
* Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
@@ -119,28 +118,12 @@ There are two editions of Onyx:
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
|
||||
To try the Onyx Enterprise Edition:
|
||||
To try the Danswer Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
ARG CA_CERT_CONTENT=""
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
@@ -38,6 +39,15 @@ RUN apt-get update && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
# Conditionally write the CA certificate and update certificates
|
||||
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
|
||||
echo "Adding custom CA certificate"; \
|
||||
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
update-ca-certificates; \
|
||||
else \
|
||||
echo "No custom CA certificate provided"; \
|
||||
fi
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
@@ -73,11 +83,11 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,7 +8,6 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
"""display custom llm models
|
||||
|
||||
Revision ID: 177de57c21c9
|
||||
Revises: 4ee1287bd26a
|
||||
Create Date: 2024-11-21 11:49:04.488677
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import and_
|
||||
|
||||
revision = "177de57c21c9"
|
||||
down_revision = "4ee1287bd26a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
llm_provider = sa.table(
|
||||
"llm_provider",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("provider", sa.String),
|
||||
sa.column("model_names", postgresql.ARRAY(sa.String)),
|
||||
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
|
||||
)
|
||||
|
||||
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
|
||||
|
||||
providers_to_update = sa.select(
|
||||
llm_provider.c.id,
|
||||
llm_provider.c.model_names,
|
||||
llm_provider.c.display_model_names,
|
||||
).where(
|
||||
and_(
|
||||
~llm_provider.c.provider.in_(excluded_providers),
|
||||
llm_provider.c.model_names.isnot(None),
|
||||
)
|
||||
)
|
||||
|
||||
results = conn.execute(providers_to_update).fetchall()
|
||||
|
||||
for provider_id, model_names, display_model_names in results:
|
||||
if display_model_names is None:
|
||||
display_model_names = []
|
||||
|
||||
combined_model_names = list(set(display_model_names + model_names))
|
||||
update_stmt = (
|
||||
llm_provider.update()
|
||||
.where(llm_provider.c.id == provider_id)
|
||||
.values(display_model_names=combined_model_names)
|
||||
)
|
||||
conn.execute(update_stmt)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,68 +0,0 @@
|
||||
"""default chosen assistants to none
|
||||
|
||||
Revision ID: 26b931506ecb
|
||||
Revises: 2daa494a0851
|
||||
Create Date: 2024-11-12 13:23:29.858995
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "26b931506ecb"
|
||||
down_revision = "2daa494a0851"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_new =
|
||||
CASE
|
||||
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants_old",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default="[-2, -1, 0]",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_old =
|
||||
CASE
|
||||
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add-group-sync-time
|
||||
|
||||
Revision ID: 2daa494a0851
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-11 10:57:22.991157
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2daa494a0851"
|
||||
down_revision = "c0fd6e4da83a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_external_group_sync",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_time_external_group_sync")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: 33cb72ea4d80
|
||||
Revises: 5b29123cd710
|
||||
Create Date: 2024-11-01 12:51:01.535003
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33cb72ea4d80"
|
||||
down_revision = "5b29123cd710"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Delete extraneous ToolCall entries
|
||||
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool_call
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM tool_call
|
||||
WHERE message_id IS NOT NULL
|
||||
GROUP BY message_id
|
||||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Add a unique constraint on message_id
|
||||
op.create_unique_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
columns=["message_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Drop the unique constraint on message_id
|
||||
op.drop_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
type_="unique",
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
"""add persona categories
|
||||
|
||||
Revision ID: 47e5bef3a1d7
|
||||
Revises: dfbe9e93d3c7
|
||||
Create Date: 2024-11-05 18:55:02.221064
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "47e5bef3a1d7"
|
||||
down_revision = "dfbe9e93d3c7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the persona_category table
|
||||
op.create_table(
|
||||
"persona_category",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add category_id to persona table
|
||||
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_persona_category",
|
||||
"persona",
|
||||
"persona_category",
|
||||
["category_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||
op.drop_column("persona", "category_id")
|
||||
op.drop_table("persona_category")
|
||||
@@ -1,280 +0,0 @@
|
||||
"""add_multiple_slack_bot_support
|
||||
|
||||
Revision ID: 4ee1287bd26a
|
||||
Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.db.models import SlackBot
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4ee1287bd26a"
|
||||
down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
|
||||
sa.Column("app_token", sa.LargeBinary(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("bot_token"),
|
||||
sa.UniqueConstraint("app_token"),
|
||||
)
|
||||
|
||||
# # Create new slack_channel_config table
|
||||
op.create_table(
|
||||
"slack_channel_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["slack_bot_id"],
|
||||
["slack_bot.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
enabled=True,
|
||||
bot_token=bot_token,
|
||||
app_token=app_token,
|
||||
)
|
||||
session.add(new_slack_bot)
|
||||
session.commit()
|
||||
first_row_id = new_slack_bot.id
|
||||
|
||||
# Create a default bot if none exists
|
||||
# This is in case there are no slack tokens but there are channels configured
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Get the bot ID to use (either from existing migration or newly created)
|
||||
bot_id_query = sa.text(
|
||||
"""
|
||||
SELECT COALESCE(
|
||||
:first_row_id,
|
||||
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
|
||||
) as bot_id;
|
||||
"""
|
||||
)
|
||||
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
|
||||
bot_id = result.scalar()
|
||||
|
||||
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
|
||||
# This splits up the channel_names into their own rows
|
||||
channel_names_cte = """
|
||||
WITH channel_names AS (
|
||||
SELECT
|
||||
sbc.id as config_id,
|
||||
sbc.persona_id,
|
||||
sbc.response_type,
|
||||
sbc.enable_auto_filters,
|
||||
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
|
||||
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
|
||||
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
|
||||
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
|
||||
sbc.channel_config->'answer_filters' as answer_filters,
|
||||
sbc.channel_config->'follow_up_tags' as follow_up_tags
|
||||
FROM slack_bot_config sbc
|
||||
)
|
||||
"""
|
||||
|
||||
# Insert the channel names into the new slack_channel_config table
|
||||
insert_statement = """
|
||||
INSERT INTO slack_channel_config (
|
||||
slack_bot_id,
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT
|
||||
:bot_id,
|
||||
channel_name.persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_name', channel_name.channel_name,
|
||||
'respond_tag_only',
|
||||
COALESCE((channel_name.respond_tag_only)::boolean, false),
|
||||
'respond_to_bots',
|
||||
COALESCE((channel_name.respond_to_bots)::boolean, false),
|
||||
'respond_member_group_list',
|
||||
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
|
||||
'answer_filters',
|
||||
COALESCE(channel_name.answer_filters, '[]'::jsonb),
|
||||
'follow_up_tags',
|
||||
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
|
||||
),
|
||||
channel_name.response_type,
|
||||
channel_name.enable_auto_filters
|
||||
FROM channel_names channel_name;
|
||||
"""
|
||||
|
||||
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
|
||||
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column
|
||||
op.alter_column(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config_id",
|
||||
new_column_name="slack_channel_config_id",
|
||||
)
|
||||
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
op.create_table(
|
||||
"slack_bot_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("response_type", sa.String(), nullable=False),
|
||||
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Migrate data back to the old format
|
||||
# Group by persona_id to combine channel names back into arrays
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot_config (
|
||||
persona_id,
|
||||
channel_config,
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
)
|
||||
SELECT DISTINCT ON (persona_id)
|
||||
persona_id,
|
||||
jsonb_build_object(
|
||||
'channel_names', (
|
||||
SELECT jsonb_agg(c.channel_config->>'channel_name')
|
||||
FROM slack_channel_config c
|
||||
WHERE c.persona_id = scc.persona_id
|
||||
),
|
||||
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
|
||||
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
|
||||
'respond_member_group_list', channel_config->'respond_member_group_list',
|
||||
'answer_filters', channel_config->'answer_filters',
|
||||
'follow_up_tags', channel_config->'follow_up_tags'
|
||||
),
|
||||
response_type,
|
||||
enable_auto_filters
|
||||
FROM slack_channel_config scc
|
||||
WHERE persona_id IS NOT NULL;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table(
|
||||
"slack_channel_config__standard_answer_category",
|
||||
"slack_bot_config__standard_answer_category",
|
||||
)
|
||||
|
||||
# Rename the column back
|
||||
op.alter_column(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
"slack_channel_config_id",
|
||||
new_column_name="slack_bot_config_id",
|
||||
)
|
||||
|
||||
# Try to save the first bot's tokens back to KV store
|
||||
try:
|
||||
first_bot = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.text(
|
||||
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if first_bot and first_bot.bot_token and first_bot.app_token:
|
||||
tokens = {
|
||||
"bot_token": first_bot.bot_token,
|
||||
"app_token": first_bot.app_token,
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
op.drop_table("slack_bot")
|
||||
@@ -1,70 +0,0 @@
|
||||
"""nullable search settings for historic index attempts
|
||||
|
||||
Revision ID: 5b29123cd710
|
||||
Revises: 949b4a92a401
|
||||
Create Date: 2024-10-30 19:37:59.630704
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5b29123cd710"
|
||||
down_revision = "949b4a92a401"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be nullable
|
||||
op.alter_column(
|
||||
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
# Add back the foreign key with ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Warning: This will delete all index attempts that don't have search settings
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE search_settings_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be not nullable
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
"search_settings_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add back the foreign key without ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
"""remove default bot
|
||||
|
||||
Revision ID: 6d562f86c78b
|
||||
Revises: 177de57c21c9
|
||||
Create Date: 2024-11-22 11:51:29.331336
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6d562f86c78b"
|
||||
down_revision = "177de57c21c9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM slack_bot
|
||||
WHERE name = 'Default Bot'
|
||||
AND bot_token = ''
|
||||
AND app_token = ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM slack_channel_config
|
||||
WHERE slack_channel_config.slack_bot_id = slack_bot.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -9,8 +9,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""add web ui option to slack config
|
||||
|
||||
Revision ID: 93560ba1b118
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-24 06:36:17.490612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93560ba1b118"
|
||||
down_revision = "6d562f86c78b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add show_continue_in_web_ui with default False to all existing channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
|
||||
WHERE NOT channel_config ? 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove show_continue_in_web_ui from all channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config - 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
@@ -7,7 +7,6 @@ Create Date: 2024-10-26 13:06:06.937969
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
# Import your models and constants
|
||||
from danswer.db.models import (
|
||||
@@ -16,6 +15,7 @@ from danswer.db.models import (
|
||||
Credential,
|
||||
IndexAttempt,
|
||||
)
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -30,11 +30,13 @@ def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Get connectors using raw SQL
|
||||
result = bind.execute(
|
||||
text("SELECT id FROM connector WHERE source = 'requesttracker'")
|
||||
connectors_to_delete = (
|
||||
session.query(Connector)
|
||||
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
|
||||
.all()
|
||||
)
|
||||
connector_ids = [row[0] for row in result]
|
||||
|
||||
connector_ids = [connector.id for connector in connectors_to_delete]
|
||||
|
||||
if connector_ids:
|
||||
cc_pairs_to_delete = (
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add creator to cc pair
|
||||
|
||||
Revision ID: 9cf5c00f72fe
|
||||
Revises: 26b931506ecb
|
||||
Create Date: 2024-11-12 15:16:42.682902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9cf5c00f72fe"
|
||||
down_revision = "26b931506ecb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"creator_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "creator_id")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add indexing trigger to cc_pair
|
||||
|
||||
Revision ID: abe7378b8217
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-26 19:09:53.481171
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abe7378b8217"
|
||||
down_revision = "93560ba1b118"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"indexing_trigger",
|
||||
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "indexing_trigger")
|
||||
@@ -288,15 +288,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""remove description from starter messages
|
||||
|
||||
Revision ID: b72ed7a5db0e
|
||||
Revises: 33cb72ea4d80
|
||||
Create Date: 2024-11-03 15:55:28.944408
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b72ed7a5db0e"
|
||||
down_revision = "33cb72ea4d80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem - 'description')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem || '{"description": ""}')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""add recent assistants
|
||||
|
||||
Revision ID: c0fd6e4da83a
|
||||
Revises: b72ed7a5db0e
|
||||
Create Date: 2024-11-03 17:28:54.916618
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0fd6e4da83a"
|
||||
down_revision = "b72ed7a5db0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
@@ -23,56 +23,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
"""extended_role_for_non_web
|
||||
|
||||
Revision ID: dfbe9e93d3c7
|
||||
Revises: 9cf5c00f72fe
|
||||
Create Date: 2024-11-16 07:54:18.727906
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "dfbe9e93d3c7"
|
||||
down_revision = "9cf5c00f72fe"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET role = 'EXT_PERM_USER'
|
||||
WHERE has_web_login = false
|
||||
"""
|
||||
)
|
||||
op.drop_column("user", "has_web_login")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET has_web_login = false,
|
||||
role = 'BASIC'
|
||||
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
|
||||
"""
|
||||
)
|
||||
@@ -1,40 +0,0 @@
|
||||
"""non-nullbale slack bot id in channel config
|
||||
|
||||
Revision ID: f7a894b06d02
|
||||
Revises: 9f696734098f
|
||||
Create Date: 2024-12-06 12:55:42.845723
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7a894b06d02"
|
||||
down_revision = "9f696734098f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all rows with null slack_bot_id
|
||||
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
|
||||
|
||||
# Make slack_bot_id non-nullable
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Make slack_bot_id nullable again
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -16,46 +16,6 @@ class ExternalAccess:
|
||||
is_public: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"external_access": {
|
||||
"external_user_emails": list(self.external_access.external_user_emails),
|
||||
"external_user_group_ids": list(
|
||||
self.external_access.external_user_group_ids
|
||||
),
|
||||
"is_public": self.external_access.is_public,
|
||||
},
|
||||
"doc_id": self.doc_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocExternalAccess":
|
||||
external_access = ExternalAccess(
|
||||
external_user_emails=set(
|
||||
data["external_access"].get("external_user_emails", [])
|
||||
),
|
||||
external_user_group_ids=set(
|
||||
data["external_access"].get("external_user_group_ids", [])
|
||||
),
|
||||
is_public=data["external_access"]["is_public"],
|
||||
)
|
||||
return cls(
|
||||
external_access=external_access,
|
||||
doc_id=data["doc_id"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import secrets
|
||||
import uuid
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import Request
|
||||
from passlib.hash import sha256_crypt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
# NOTE for others who are curious: In the context of a header, "X-" often refers
|
||||
# to non-standard, experimental, or custom headers in HTTP or other protocols. It
|
||||
# indicates that the header is not part of the official standards defined by
|
||||
# organizations like the Internet Engineering Task Force (IETF).
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
|
||||
_BEARER_PREFIX = "Bearer "
|
||||
_API_KEY_PREFIX = "dn_"
|
||||
_API_KEY_LEN = 192
|
||||
|
||||
|
||||
class ApiKeyDescriptor(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
api_key: str | None = None # only present on initial creation
|
||||
api_key_name: str | None = None
|
||||
api_key_role: UserRole
|
||||
|
||||
user_id: uuid.UUID
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
|
||||
|
||||
|
||||
def extract_tenant_from_api_key_header(request: Request) -> str | None:
|
||||
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
|
||||
raw_api_key_header = request.headers.get(
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME
|
||||
) or request.headers.get(_API_KEY_HEADER_NAME)
|
||||
|
||||
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
|
||||
return None
|
||||
|
||||
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
if not api_key.startswith(_API_KEY_PREFIX):
|
||||
return None
|
||||
|
||||
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
tenant_id = parts[0]
|
||||
return unquote(tenant_id) if tenant_id else None
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
# NOTE: no salt is needed, as the API key is randomly generated
|
||||
# and overlaps are impossible
|
||||
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
|
||||
|
||||
|
||||
def build_displayable_api_key(api_key: str) -> str:
|
||||
if api_key.startswith(_API_KEY_PREFIX):
|
||||
api_key = api_key[len(_API_KEY_PREFIX) :]
|
||||
|
||||
return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]
|
||||
|
||||
|
||||
def get_hashed_api_key_from_request(request: Request) -> str | None:
|
||||
raw_api_key_header = request.headers.get(
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME
|
||||
) or request.headers.get(_API_KEY_HEADER_NAME)
|
||||
if raw_api_key_header is None:
|
||||
return None
|
||||
|
||||
if raw_api_key_header.startswith(_BEARER_PREFIX):
|
||||
raw_api_key_header = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
return hash_api_key(raw_api_key_header)
|
||||
@@ -2,8 +2,8 @@ from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_USER_STORE_KEY
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
|
||||
@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(
|
||||
chosen_assistants=None, default_model=None, auto_scroll=True
|
||||
)
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
|
||||
@@ -13,24 +13,12 @@ class UserRole(str, Enum):
|
||||
groups they are curators of
|
||||
- Global Curator can perform admin actions
|
||||
for all groups they are a member of
|
||||
- Limited can access a limited set of basic api endpoints
|
||||
- Slack are users that have used danswer via slack but dont have a web login
|
||||
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
|
||||
"""
|
||||
|
||||
LIMITED = "limited"
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
GLOBAL_CURATOR = "global_curator"
|
||||
SLACK_USER = "slack_user"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
return self not in [
|
||||
UserRole.SLACK_USER,
|
||||
UserRole.EXT_PERM_USER,
|
||||
]
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
@@ -45,8 +33,10 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
has_web_login: bool | None = True
|
||||
|
||||
@@ -48,16 +48,18 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -73,28 +75,28 @@ from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -131,12 +133,11 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
if AUTH_TYPE == AuthType.BASIC:
|
||||
return REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
# For other auth types, if the user is authenticated it's assumed that
|
||||
# the user is already verified via the external IDP
|
||||
return False
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
@@ -189,6 +190,20 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
@@ -217,26 +232,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -244,9 +258,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
|
||||
@@ -259,15 +271,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
@@ -275,13 +292,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
self,
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
oauth_name: str,
|
||||
access_token: str,
|
||||
account_id: str,
|
||||
@@ -292,24 +307,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
*,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
) -> models.UOAP:
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -318,11 +329,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_domain(account_email)
|
||||
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
self.database = tenant_user_db
|
||||
self.database = tenant_user_db # type: ignore
|
||||
|
||||
oauth_account_dict = {
|
||||
"oauth_name": oauth_name,
|
||||
@@ -362,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -374,11 +383,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
user, existing_oauth_account, oauth_account_dict
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
@@ -391,15 +396,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login():
|
||||
if not user.has_web_login: # type: ignore
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True # type: ignore
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -447,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
@@ -474,8 +474,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.role.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -507,30 +510,19 @@ cookie_transport = CookieTransport(
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
@@ -605,7 +597,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
async_db_session: AsyncSession,
|
||||
db_session: Session,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
@@ -614,21 +606,13 @@ async def optional_user_(
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
user = await versioned_fetch_user(request, user, async_db_session)
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
return user
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
@@ -640,12 +624,14 @@ async def double_check_user(
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
@@ -654,7 +640,8 @@ async def double_check_user(
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
@@ -667,24 +654,10 @@ async def current_user_with_expired_token(
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_limited_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
user = await double_check_user(user)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if user.role == UserRole.LIMITED:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
return user
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
@@ -694,13 +667,15 @@ async def current_curator_or_admin_user(
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
@@ -712,7 +687,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
@@ -740,6 +716,8 @@ def generate_state_token(
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
|
||||
|
||||
def create_danswer_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
@@ -789,22 +767,15 @@ def get_oauth_router(
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
scopes: List[str] = Query(None),
|
||||
request: Request, scopes: List[str] = Query(None)
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state_data: Dict[str, str] = {"next_url": next_url}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
@@ -863,11 +834,8 @@ def get_oauth_router(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
# Authenticate user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
@@ -909,25 +877,7 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def api_key_dep(
|
||||
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
|
||||
@@ -3,7 +3,6 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
@@ -11,26 +10,19 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -116,43 +108,29 @@ def on_task_postrun(
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(tenant_id, int(document_set_id))
|
||||
rds = RedisDocumentSet(int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(tenant_id, int(usergroup_id))
|
||||
rug = RedisUserGroup(int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDelete.PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
rcd = RedisConnectorDeletion(int(cc_pair_id))
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPermissionSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorExternalGroupSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
rcp = RedisConnectorPruning(int(cc_pair_id))
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
@@ -162,154 +140,27 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness probe starting.")
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Redis: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Redis: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for the db to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Database: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result:
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Database: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Database: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
@@ -317,7 +168,57 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for all tenant primary workers to be ready...")
|
||||
time_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
# Check if we have a primary worker lock for each tenant
|
||||
all_tenants_ready = all(
|
||||
get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
for tenant_id in tenant_ids
|
||||
)
|
||||
|
||||
if all_tenants_ready:
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
ready_tenants = sum(
|
||||
1
|
||||
for tenant_id in tenant_ids
|
||||
if get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Not all tenant primary workers are ready yet. "
|
||||
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Not all tenant primary workers were ready within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("All tenant primary workers are ready. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
@@ -329,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
for tenant_id, lock in sender.primary_worker_locks.items():
|
||||
try:
|
||||
if lock and lock.owned():
|
||||
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
|
||||
try:
|
||||
lock.release()
|
||||
logger.debug(f"Successfully released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
sender.primary_worker_locks[tenant_id] = None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -3,162 +3,26 @@ from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._update_tenant_tasks()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
now = self.app.now()
|
||||
if (
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating tenant task update")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Tenant task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting tenant task update process")
|
||||
try:
|
||||
logger.info("Fetching all tenant IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} tenants")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
logger.debug(f"Creating task configuration for {task_name}")
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_beat_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
logger.debug("Comparing current and new schedules")
|
||||
current_tasks = set(name for name, _ in current_schedule)
|
||||
new_tasks = set(new_schedule.keys())
|
||||
needs_update = current_tasks != new_tasks
|
||||
logger.debug(f"Schedule update needed: {needs_update}")
|
||||
return needs_update
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -169,4 +33,68 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration once
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -59,15 +58,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -91,7 +84,5 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
@@ -14,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -60,15 +58,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -82,11 +74,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -91,7 +85,5 @@ celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
@@ -11,33 +10,25 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.background.celery.tasks.indexing.tasks import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -84,98 +75,108 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
# by the primary worker
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
role: str = cast(str, info.get("role"))
|
||||
connected_slaves: int = info.get("connected_slaves", 0)
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
logger.info(
|
||||
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
|
||||
)
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_locks[tenant_id] = lock
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisDocumentSet.reset_all(r)
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisUserGroup.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorPrune.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorIndex.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorPermissionSync.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
continue
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
failure_reason = (
|
||||
f"Canceling leftover index attempt found on startup: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
# @worker_process_init.connect
|
||||
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
||||
# """This only runs inside child processes when the worker is in pool=prefork mode.
|
||||
# This may be technically unnecessary since we're finding prefork pools to be
|
||||
# unstable and currently aren't planning on using them."""
|
||||
# logger.info("worker_process_init signal received.")
|
||||
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
|
||||
|
||||
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
|
||||
# SqlEngine.get_engine().dispose(close=False)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
@@ -228,36 +229,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
if not hasattr(worker, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
# Retrieve all tenant IDs
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
for tenant_id in tenant_ids:
|
||||
lock = worker.primary_worker_locks.get(tenant_id)
|
||||
if not lock:
|
||||
continue # Skip if no lock for this tenant
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
worker.primary_worker_lock = lock
|
||||
if lock.owned():
|
||||
task_logger.debug(
|
||||
f"Reacquiring primary worker lock for tenant {tenant_id}."
|
||||
)
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
task_logger.warning(
|
||||
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
|
||||
)
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
|
||||
)
|
||||
worker.primary_worker_locks[tenant_id] = lock
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Periodic task failed.")
|
||||
@@ -276,8 +293,6 @@ celery_app.autodiscover_tasks(
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
|
||||
@@ -1,10 +1,568 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: str):
|
||||
self._id: str = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class is used to scan documents by cc_pair in the db and collect them into
|
||||
a unified set for syncing.
|
||||
|
||||
It differs from the other redis helpers in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorIndexing(RedisObjectHelper):
|
||||
"""Celery will kick off a long running indexing task to crawl the connector and
|
||||
find any new or updated docs docs, which will each then get a new sync task or be
|
||||
indexed inline.
|
||||
|
||||
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
|
||||
e.g. "2/5"
|
||||
"""
|
||||
|
||||
PREFIX = "connectorindexing"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
|
||||
super().__init__(f"{cc_pair_id}/{search_settings_id}")
|
||||
|
||||
@property
|
||||
def generator_lock_key(self) -> str:
|
||||
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
def is_indexing(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorStop(RedisObjectHelper):
|
||||
"""Used to signal any running tasks for a connector to stop. We should refactor
|
||||
connector related redis helpers into a single class.
|
||||
"""
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
@@ -16,8 +18,7 @@ from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -40,14 +41,14 @@ def _get_deletion_status(
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
|
||||
|
||||
@@ -78,10 +79,10 @@ def document_batch_to_ids(
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasnt been implemented for the given connector, just pull
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
@@ -111,15 +112,10 @@ def extract_ids_from_runnable_connector(
|
||||
for doc_batch in doc_batch_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
raise RuntimeError("Stop signal received")
|
||||
callback.progress(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
@@ -1,25 +1,30 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@@ -29,7 +34,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -37,7 +42,7 @@ class TaskDependencyError(RuntimeError):
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -57,19 +62,19 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, lock_beat, tenant_id
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
redis_connector.stop.set_fence(True)
|
||||
r.set(rcs.fence_key, cc_pair_id)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
r.delete(rcs.fence_key)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
@@ -86,7 +91,8 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
lock_beat: RedisLock,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
@@ -100,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if redis_connector.delete.fenced:
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to load the state of the object inside the fence
|
||||
@@ -117,55 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
fence_payload = RedisConnectorDeletePayload(
|
||||
fence_value = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
|
||||
if r.get(rci.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if r.get(rcp.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
redis_connector.delete.taskset_clear()
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.delete.generate_tasks(
|
||||
app, db_session, lock_beat
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
@@ -180,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
fence_value.num_tasks = tasks_generated
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.document import upsert_document_external_perms
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip doc permissions sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_perm_sync = cc_pair.last_time_perm_sync
|
||||
if last_perm_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
if redis_connector.permissions.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
bind=True,
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
return False
|
||||
@@ -1,298 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external group sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_ext_group_sync = cc_pair.last_time_external_group_sync
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Dont kick off a new sync if the previous one is still running
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=external_user_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -2,171 +2,99 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import delete_index_attempt
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_active_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
def __init__(
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
self.last_tag = tag
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned. "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
|
||||
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
|
||||
want to clean them up.
|
||||
|
||||
Unfenced = attempt not in terminal state and fence does not exist.
|
||||
"""
|
||||
unfenced_attempts: list[int] = []
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
# inner = index_attempt in non terminal state
|
||||
# outer = r.fence_key down
|
||||
|
||||
# check the db for index attempts in a non terminal state
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
|
||||
# if the fence is down / doesn't exist, possible error but not confirmed
|
||||
if r.exists(fence_key):
|
||||
continue
|
||||
|
||||
# Between the time the attempts are first looked up and the time we see the fence down,
|
||||
# the attempt may have completed and taken down the fence normally.
|
||||
|
||||
# We need to double check that the index attempt is still in a non terminal state
|
||||
# and matches the original state, which confirms we are really in a bad state.
|
||||
attempt_2 = get_index_attempt(db_session, attempt.id)
|
||||
if not attempt_2:
|
||||
continue
|
||||
|
||||
if attempt.status != attempt_2.status:
|
||||
continue
|
||||
|
||||
unfenced_attempts.append(attempt.id)
|
||||
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -176,49 +104,44 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
if old_search_settings:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# only warm up if search settings were changed
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# gather cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
# kick off index attempts
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
rci = RedisConnectorIndexing(
|
||||
cc_pair_id, search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -230,80 +153,31 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing manual trigger detected: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"indexing_mode={cc_pair.indexing_trigger}"
|
||||
)
|
||||
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
reindex,
|
||||
False,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Connector indexing queued: "
|
||||
f"index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
continue
|
||||
|
||||
failure_reason = (
|
||||
f"Unfenced index attempt found in DB: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
task_logger.error(failure_reason)
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -311,14 +185,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return tasks_created
|
||||
|
||||
@@ -327,7 +195,6 @@ def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@@ -392,11 +259,6 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
@@ -429,11 +291,10 @@ def try_creating_indexing_task(
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
index_attempt_id: int | None = None
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -443,15 +304,15 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -459,17 +320,19 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.taskset_key)
|
||||
|
||||
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexPayload(
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_fence(payload)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
@@ -482,12 +345,8 @@ def try_creating_indexing_task(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -502,20 +361,17 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
fence_value.index_attempt_id = index_attempt_id
|
||||
fence_value.celery_task_id = result.id
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
r.delete(rci.fence_key)
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if index_attempt_id is not None:
|
||||
delete_index_attempt(db_session, index_attempt_id)
|
||||
redis_connector_index.set_fence(None)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
@@ -524,34 +380,19 @@ def try_creating_indexing_task(
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task_wrapper,
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
@@ -561,142 +402,32 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
sleep(10)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
job.cancel()
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def connector_indexing_task_wrapper(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Just wraps connector_indexing_task so we can log any exceptions before
|
||||
re-raising it."""
|
||||
result: int | None = None
|
||||
|
||||
try:
|
||||
result = connector_indexing_task(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
@@ -715,102 +446,78 @@ def connector_indexing_task(
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
attempt = None
|
||||
n_final_progress = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
if r.exists(rcd.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
f"fence={rcd.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
if r.exists(rcs.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
f"fence={rcs.fence_key}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
while True:
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
lock = r.lock(
|
||||
rci.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
fence_data.started = datetime.now(timezone.utc)
|
||||
r.set(rci.fence_key, fence_data.model_dump_json())
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -819,7 +526,6 @@ def connector_indexing_task(
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -839,52 +545,43 @@ def connector_indexing_task(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rci.generator_progress_key, lock, r
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
# get back the total number of indexed docs and return it
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
n_final_progress = int(cast(int, generator_progress_value))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
|
||||
if attempt:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return n_final_progress
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,12 +8,14 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
@@ -21,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -32,52 +33,14 @@ from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import pruning_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due.
|
||||
|
||||
Next pruning time is calculated as a delta from the last successful prune, or the
|
||||
last successful indexing if pruning has never succeeded.
|
||||
|
||||
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
|
||||
if pruning fails (which is what it does now). A backoff could be reasonable.
|
||||
"""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -107,7 +70,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
if not is_pruning_due(cc_pair, db_session, r):
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
@@ -128,6 +91,47 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def is_pruning_due(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
) -> bool:
|
||||
"""Returns an int if pruning is triggered.
|
||||
The int represents the number of prune tasks generated (in this case, only one
|
||||
because the task is a long running generator task.)
|
||||
Returns None if no pruning is triggered (due to not being needed or
|
||||
other reasons such as simultaneous pruning restrictions.
|
||||
|
||||
Checks for scheduling related conditions, then delegates the rest of the checks to
|
||||
try_creating_prune_generator_task.
|
||||
"""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
@@ -142,11 +146,8 @@ def try_creating_prune_generator_task(
|
||||
is used to trigger prunes immediately, e.g. via the web ui.
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
count = redis_connector.prune.get_active_task_count()
|
||||
if count > 0:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
@@ -163,16 +164,15 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# skip pruning if already pruning
|
||||
if redis_connector.prune.fenced:
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
# skip pruning if doc permissions sync is running
|
||||
if redis_connector.permissions.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -180,13 +180,13 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -199,7 +199,7 @@ def try_creating_prune_generator_task(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
r.set(rcp.fence_key, 1)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -229,23 +229,13 @@ def connector_pruning_generator_task(
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
pruning_ctx.set(pruning_ctx_dict)
|
||||
|
||||
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
@@ -269,11 +259,6 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
@@ -282,13 +267,11 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
callback = IndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rcp.generator_progress_key, lock, r
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
@@ -310,34 +293,36 @@ def connector_pruning_generator_task(
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
self.app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorDeletionFenceData(BaseModel):
|
||||
num_tasks: int | None
|
||||
submitted: datetime
|
||||
@@ -0,0 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
@@ -9,7 +9,6 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
@@ -60,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
|
||||
task_logger.info(f"tenant={tenant_id} doc={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -142,9 +141,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
@@ -174,21 +171,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
else:
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
task_logger.info(
|
||||
f"Max retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
# in vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
with get_session_with_tenant(tenant_id):
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
return False
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@@ -12,23 +13,32 @@ from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from danswer.db.connector import mark_ccpair_as_pruned
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
@@ -49,24 +59,15 @@ from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
@@ -81,7 +82,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -166,7 +167,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
@@ -184,34 +185,30 @@ def try_generate_stale_document_sync_tasks(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
"RedisConnector.generate_tasks starting by cc_pair. "
|
||||
"Documents spanning multiple cc_pairs will only be synced once."
|
||||
)
|
||||
|
||||
docs_to_skip: set[str] = set()
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
rc.set_skip_docs(docs_to_skip)
|
||||
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
if result is None:
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
if result[1] == 0:
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for single cc_pair. "
|
||||
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += result[0]
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
@@ -226,15 +223,15 @@ def try_generate_document_set_sync_tasks(
|
||||
document_set_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if rds.fenced:
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
@@ -254,11 +251,12 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
if result is None:
|
||||
tasks_generated = rds.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
tasks_generated = result[0]
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
@@ -267,11 +265,11 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -280,14 +278,15 @@ def try_generate_user_group_sync_tasks(
|
||||
usergroup_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(tenant_id, usergroup_id)
|
||||
if rug.fenced:
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
@@ -309,11 +308,12 @@ def try_generate_user_group_sync_tasks(
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
if result is None:
|
||||
tasks_generated = rug.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
tasks_generated = result[0]
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
@@ -322,11 +322,11 @@ def try_generate_user_group_sync_tasks(
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -352,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -362,12 +362,16 @@ def monitor_document_set_taskset(
|
||||
|
||||
document_set_id = int(document_set_id_str)
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
if not rds.fenced:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial_count = rds.payload
|
||||
if initial_count is None:
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
@@ -395,38 +399,48 @@ def monitor_document_set_taskset(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
key_bytes: bytes, r: Redis, tenant_id: str | None
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rcd.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
# the fence is setting up but isn't ready yet
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -442,22 +456,11 @@ def monitor_connector_deletion_taskset(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
|
||||
# gating off pruning and indexing work before deletion starts
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# We don't want to waive off why we get into this state, but resetting
|
||||
# our attempt and letting the deletion restart is a good way to recover
|
||||
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
f"Connector deletion - documents still found after taskset completion: "
|
||||
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
@@ -521,14 +524,15 @@ def monitor_connector_deletion_taskset(
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
@@ -537,76 +541,46 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
@@ -621,94 +595,103 @@ def monitor_ccpair_indexing_taskset(
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
|
||||
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
progress_count = int(cast(int, generator_progress_value))
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress_count} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
|
||||
)
|
||||
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
generator_complete_value = r.get(rci.generator_complete_key)
|
||||
if generator_complete_value is None:
|
||||
if result_state in READY_STATES:
|
||||
# IF the task state is READY, THEN generator_complete should be set
|
||||
# if it isn't, then the worker crashed
|
||||
task_logger.info(
|
||||
f"Connector indexing aborted: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason="Connector indexing aborted or exceptioned.",
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
try:
|
||||
status_value = int(cast(int, generator_complete_value))
|
||||
status_enum = HTTPStatus(status_value)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"monitor_ccpair_indexing_taskset: "
|
||||
f"generator_complete_value=f{generator_complete_value} could not be parsed."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
|
||||
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -717,11 +700,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
Returns True if the task actually did work, False
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -733,7 +716,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
@@ -746,33 +729,49 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
n_pruning = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"pruning={n_pruning}"
|
||||
)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
rci = RedisConnectorIndexing(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
|
||||
if not r.exists(rci.fence_key):
|
||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
@@ -783,25 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
@@ -819,7 +812,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
@@ -873,9 +866,7 @@ def vespa_metadata_sync_task(
|
||||
)
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = celery_app
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.beat", "celery_app"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
)
|
||||
|
||||
@@ -11,8 +11,7 @@ from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -29,26 +28,16 @@ JobStatusType = (
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on SQLAlchemy's recommendations to handle multiprocessing:
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
# Optionally set a custom app name for database logging purposes
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
get_sqlalchemy_engine().dispose(close=False)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -82,7 +71,7 @@ class SimpleJob:
|
||||
return "running"
|
||||
elif self.process.exitcode is None:
|
||||
return "cancelled"
|
||||
elif self.process.exitcode != 0:
|
||||
elif self.process.exitcode > 0:
|
||||
return "error"
|
||||
else:
|
||||
return "finished"
|
||||
@@ -123,8 +112,7 @@ class SimpleJobClient:
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug(
|
||||
f"No available workers to run job. "
|
||||
f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -19,7 +21,6 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
@@ -30,10 +31,10 @@ from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.logger import TaskAttemptSingleton
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -41,6 +42,19 @@ logger = setup_logger()
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
class RunIndexingCallbackInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, amount: int) -> None:
|
||||
"""Send progress updates to the caller."""
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
@@ -88,15 +102,11 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
@@ -108,13 +118,7 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
@@ -128,7 +132,13 @@ def _run_indexing(
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings,
|
||||
callback=callback,
|
||||
heartbeat=IndexingHeartbeat(
|
||||
index_attempt_id=index_attempt.id,
|
||||
db_session=db_session,
|
||||
# let the world know we're still making progress after
|
||||
# every 10 batches
|
||||
freq=10,
|
||||
),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -141,7 +151,6 @@ def _run_indexing(
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
@@ -213,7 +222,7 @@ def _run_indexing(
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
raise RuntimeError("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
@@ -274,7 +283,7 @@ def _run_indexing(
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch))
|
||||
callback.progress(len(doc_batch))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
@@ -307,16 +316,26 @@ def _run_indexing(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
mark_attempt_canceled(
|
||||
index_attempt.id,
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt,
|
||||
db_session,
|
||||
reason=str(e),
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -328,37 +347,6 @@ def _run_indexing(
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
@@ -378,7 +366,7 @@ def _run_indexing(
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
@@ -425,7 +413,7 @@ def run_indexing_entrypoint(
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
if is_ee:
|
||||
@@ -433,19 +421,17 @@ def run_indexing_entrypoint(
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing starting for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
@@ -453,8 +439,10 @@ def run_indexing_entrypoint(
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.chat.prompt_builder.build import default_build_system_message
|
||||
from danswer.chat.prompt_builder.build import default_build_user_message
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.chat.stream_processing.utils import map_document_id_order
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
question: str,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
force_use_tool: ForceUseTool,
|
||||
# must be the same length as `docs`. If None, all docs are considered "relevant"
|
||||
message_history: list[PreviousMessage] | None = None,
|
||||
single_message_history: str | None = None,
|
||||
# newly passed in files to include as part of this question
|
||||
# TODO THIS NEEDS TO BE HANDLED
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
|
||||
# but we only support them anyways
|
||||
# if set to True, then never use the LLMs provided tool-calling functonality
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
return_contexts: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
"Cannot provide both `message_history` and `single_message_history`"
|
||||
)
|
||||
|
||||
self.question = question
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
self.latest_query_files = latest_query_files or []
|
||||
self.file_id_to_file = {file.file_id: file for file in (files or [])}
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.answer_style_config = answer_style_config
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
self.llm = llm
|
||||
self.llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
self._final_prompt: list[BaseMessage] | None = None
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: (
|
||||
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
|
||||
) = None
|
||||
|
||||
self._return_contexts = return_contexts
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
and current_llm_call.force_use_tool.args is not None
|
||||
):
|
||||
tool_name, tool_args = (
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
)
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# if we're skipping gen ai answer generation, we should break
|
||||
# out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# answer, which is a no-no!
|
||||
if (
|
||||
self.skip_gen_ai_answer_generation
|
||||
and not current_llm_call.force_use_tool.force_use
|
||||
):
|
||||
return
|
||||
|
||||
# set up "handlers" to listen to the LLM response stream and
|
||||
# feed back the processed results + handle tool call requests
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
"required"
|
||||
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
else None
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
single_message_history=self.single_message_history,
|
||||
raw_user_text=self.question,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
return citations
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
if self.is_connected is not None:
|
||||
if not self.is_connected():
|
||||
logger.debug("Answer stream has been cancelled")
|
||||
self._is_cancelled = not self.is_connected()
|
||||
|
||||
return self._is_cancelled
|
||||
@@ -2,79 +2,20 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -295,71 +196,3 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: ToolResponseHandler,
|
||||
answer_handler: AnswerResponseHandler,
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
self.answer_handler = answer_handler
|
||||
self.is_cancelled = is_cancelled
|
||||
|
||||
def handle_llm_response(
|
||||
self,
|
||||
stream: Iterator[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
all_messages: list[BaseMessage] = []
|
||||
for message in stream:
|
||||
if self.is_cancelled():
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
return
|
||||
# tool handler doesn't do anything until the full message is received
|
||||
# NOTE: still need to run list() to get this to run
|
||||
list(self.tool_handler.handle_response_part(message, all_messages))
|
||||
yield from self.answer_handler.handle_response_part(message, all_messages)
|
||||
all_messages.append(message)
|
||||
|
||||
# potentially give back all info on the selected tool call + its result
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
@@ -15,6 +14,7 @@ from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.db.persona import get_prompt_by_name
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(
|
||||
@@ -79,12 +79,8 @@ def load_personas_from_yaml(
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
if not prompt_ids:
|
||||
raise ValueError("Invalid Persona config, no prompts exist")
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
@@ -126,16 +122,12 @@ def load_personas_from_yaml(
|
||||
tool_ids=tool_ids,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=(
|
||||
existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority")
|
||||
),
|
||||
is_visible=(
|
||||
existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible")
|
||||
),
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -1,29 +1,16 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -38,7 +25,6 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -131,6 +117,20 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -146,23 +146,17 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
@@ -171,44 +165,12 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
@@ -221,109 +183,3 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
class DocumentPruningConfig(BaseModel):
|
||||
max_chunks: int | None = None
|
||||
max_window_percentage: float | None = None
|
||||
max_tokens: int | None = None
|
||||
# different pruning behavior is expected when the
|
||||
# user manually selects documents they want to chat with
|
||||
# e.g. we don't want to truncate each document to be no more
|
||||
# than one chunk long
|
||||
is_manually_selected_docs: bool = False
|
||||
# If user specifies to include additional context Chunks for each match, then different pruning
|
||||
# is used. As many Sections as possible are included, and the last Section is truncated
|
||||
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
|
||||
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
|
||||
use_sections: bool = True
|
||||
# If using tools, then we need to consider the tool length
|
||||
tool_num_tokens: int = 0
|
||||
# If using a tool message to represent the docs, then we have to JSON serialize
|
||||
# the document content, which adds to the token count.
|
||||
using_tool_message: bool = False
|
||||
|
||||
|
||||
class ContextualPruningConfig(DocumentPruningConfig):
|
||||
num_chunk_multiple: int
|
||||
|
||||
@classmethod
|
||||
def from_doc_pruning_config(
|
||||
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
|
||||
) -> "ContextualPruningConfig":
|
||||
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
)
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ personas:
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Search"
|
||||
name: "Knowledge"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
@@ -42,14 +42,18 @@ personas:
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Give me an overview of what's here"
|
||||
message: "Sample some documents and tell me what you find."
|
||||
- name: "Use AI to solve a work related problem"
|
||||
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
|
||||
- name: "Find updates on a topic of interest"
|
||||
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
|
||||
- name: "Surface contradictions"
|
||||
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
|
||||
- name: "General Information"
|
||||
description: "Ask about available information"
|
||||
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
|
||||
- name: "Specific Topic Search"
|
||||
description: "Search for specific information"
|
||||
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
|
||||
- name: "Recent Updates"
|
||||
description: "Inquire about latest additions"
|
||||
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
|
||||
- name: "Cross-referencing Information"
|
||||
description: "Connect information from different sources"
|
||||
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
@@ -67,14 +71,18 @@ personas:
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Summarize a document"
|
||||
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
|
||||
- name: "Help me with coding"
|
||||
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
|
||||
- name: "Draft a professional email"
|
||||
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
|
||||
- name: "Learn something new"
|
||||
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
|
||||
- name: "Open Discussion"
|
||||
description: "Start an open-ended conversation"
|
||||
message: "Hi! Can you help me write a professional email?"
|
||||
- name: "Problem Solving"
|
||||
description: "Get help with a challenge"
|
||||
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
|
||||
- name: "Learn Something New"
|
||||
description: "Explore a new topic"
|
||||
message: "Hi! Could you explain what project management is in simple terms?"
|
||||
- name: "Creative Brainstorming"
|
||||
description: "Generate creative ideas"
|
||||
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -93,12 +101,16 @@ personas:
|
||||
is_visible: false
|
||||
starter_messages:
|
||||
- name: "Document Search"
|
||||
description: "Find exact information"
|
||||
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
|
||||
- name: "Process Verification"
|
||||
description: "Find exact quotes"
|
||||
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
|
||||
- name: "Technical Documentation"
|
||||
description: "Search technical details"
|
||||
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
|
||||
- name: "Policy Reference"
|
||||
description: "Check official policies"
|
||||
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
|
||||
|
||||
- id: 3
|
||||
@@ -118,11 +130,15 @@ personas:
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Create visuals for a presentation"
|
||||
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
|
||||
- name: "Find inspiration for a marketing campaign"
|
||||
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
|
||||
- name: "Visualize a product design"
|
||||
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
|
||||
- name: "Generate a humorous image response"
|
||||
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"
|
||||
- name: "Landscape"
|
||||
description: "Generate a landscape image"
|
||||
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
|
||||
- name: "Character"
|
||||
description: "Generate a character image"
|
||||
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
|
||||
- name: "Abstract"
|
||||
description: "Create an abstract image"
|
||||
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
|
||||
- name: "Urban Scene"
|
||||
description: "Generate an urban landscape"
|
||||
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."
|
||||
@@ -6,41 +6,28 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.answer import Answer
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DocumentPruningConfig
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.context.search.utils import drop_llm_indices
|
||||
from danswer.context.search.utils import relevant_sections_to_indices
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -53,6 +40,7 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -62,65 +50,64 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.utils import load_all_chat_files
|
||||
from danswer.file_store.utils import save_files
|
||||
from danswer.file_store.utils import save_files_from_urls
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -263,18 +250,16 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -290,12 +275,11 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -303,10 +287,6 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -328,36 +308,17 @@ def stream_chat_message_objects(
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -372,7 +333,6 @@ def stream_chat_message_objects(
|
||||
persona=persona,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
@@ -448,20 +408,12 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -532,19 +484,13 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -559,13 +505,7 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -577,87 +517,163 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=answer_style_config,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
@@ -680,8 +696,7 @@ def stream_chat_message_objects(
|
||||
|
||||
reference_db_search_docs = None
|
||||
qa_docs_response = None
|
||||
# any files to associate with the AI message e.g. dall-e generated images
|
||||
ai_message_files = []
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
|
||||
@@ -726,6 +741,7 @@ def stream_chat_message_objects(
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
@@ -736,20 +752,14 @@ def stream_chat_message_objects(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
if img.image_data
|
||||
],
|
||||
tenant_id=tenant_id,
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
@@ -763,38 +773,11 @@ def stream_chat_message_objects(
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=(
|
||||
ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV
|
||||
),
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(DanswerContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
@@ -824,15 +807,13 @@ def stream_chat_message_objects(
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
logger.debug("Post-LLM answer processing")
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=answer.citations)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
@@ -841,6 +822,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -848,21 +830,21 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else None
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
@@ -886,6 +868,7 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -895,36 +878,10 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_slack(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatDanswerBotResponse:
|
||||
response = ChatDanswerBotResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
use_language_hint = bool(get_multilingual_expansion())
|
||||
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
|
||||
|
||||
history_block = ""
|
||||
if history_str:
|
||||
history_block = HISTORY_BLOCK.format(history_str=history_str)
|
||||
|
||||
full_prompt = JSON_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
history_block=history_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
message: HumanMessage,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return _build_strong_llm_quotes_prompt(
|
||||
question=query,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
)
|
||||
@@ -1,62 +0,0 @@
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_danswer_msg_to_langchain(msg)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
@@ -9,19 +9,19 @@ prompts:
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
|
||||
If there are no relevant documents, refer to the chat history and your internal knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
|
||||
@@ -30,21 +30,21 @@ prompts:
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "ImageGeneration"
|
||||
description: "Generates images from user descriptions!"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
|
||||
|
||||
For appropriate requests, you will generate an image that matches the user's requirements.
|
||||
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
|
||||
|
||||
You aim to be helpful while maintaining appropriate content standards.
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Based on the user's description, create a high-quality image that accurately reflects their request.
|
||||
Pay close attention to the specified details, styles, and desired elements.
|
||||
|
||||
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
@@ -64,13 +64,14 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
@@ -83,6 +84,7 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
@@ -90,10 +92,10 @@ prompts:
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
@@ -1,93 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.llm_response_handler import ResponsePart
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class AnswerResponseHandler(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
# This is a dummy handler that returns nothing
|
||||
yield from []
|
||||
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
# TODO remove this after citation issue is resolved
|
||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# Process the new content through the citation processor
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
@@ -1,175 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def in_code_block(llm_text: str) -> bool:
|
||||
count = llm_text.count(TRIPLE_BACKTICK)
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
self.current_citations: list[int] = []
|
||||
self.past_cite_count = 0
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
|
||||
return
|
||||
|
||||
if self.stop_stream:
|
||||
next_hold = self.hold + token
|
||||
if self.stop_stream in next_hold:
|
||||
return
|
||||
if next_hold == self.stop_stream[: len(next_hold)]:
|
||||
self.hold = next_hold
|
||||
return
|
||||
token = next_hold
|
||||
self.hold = ""
|
||||
|
||||
self.curr_segment += token
|
||||
self.llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
result = ""
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
self.curr_segment = self.curr_segment[last_citation_end:]
|
||||
|
||||
if not possible_citation_found:
|
||||
result += self.curr_segment
|
||||
self.curr_segment = ""
|
||||
|
||||
if result:
|
||||
yield DanswerAnswerPiece(answer_piece=result)
|
||||
@@ -1,207 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolResponseHandler:
|
||||
def __init__(self, tools: list[Tool]):
|
||||
self.tools = tools
|
||||
|
||||
self.tool_call_chunk: AIMessageChunk | None = None
|
||||
self.tool_call_requests: list[ToolCall] = []
|
||||
|
||||
self.tool_runner: ToolRunner | None = None
|
||||
self.tool_call_summary: ToolCallSummary | None = None
|
||||
|
||||
self.tool_kickoff: ToolCallKickoff | None = None
|
||||
self.tool_responses: list[ToolResponse] = []
|
||||
self.tool_final_result: ToolCallFinalResult | None = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_call_for_non_tool_calling_llm(
|
||||
cls, llm_call: LLMCall, llm: LLM
|
||||
) -> tuple[Tool, dict] | None:
|
||||
if llm_call.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
(
|
||||
t
|
||||
for t in llm_call.tools
|
||||
if t.name == llm_call.force_use_tool.tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(llm_call.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
return chosen_tool_and_args
|
||||
|
||||
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
|
||||
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
||||
return
|
||||
|
||||
self.tool_call_requests = self.tool_call_chunk.tool_calls
|
||||
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in self.tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
|
||||
if selected_tool and selected_tool_call_request:
|
||||
break
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
return
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
|
||||
self.tool_kickoff = self.tool_runner.kickoff()
|
||||
yield self.tool_kickoff
|
||||
|
||||
for response in self.tool_runner.tool_responses():
|
||||
self.tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
self.tool_final_result = self.tool_runner.tool_final_result()
|
||||
yield self.tool_final_result
|
||||
|
||||
self.tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=self.tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
selected_tool_call_request, self.tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self._handle_tool_call()
|
||||
|
||||
if isinstance(response_item, AIMessageChunk) and (
|
||||
response_item.tool_call_chunks or response_item.tool_calls
|
||||
):
|
||||
if self.tool_call_chunk is None:
|
||||
self.tool_call_chunk = response_item
|
||||
else:
|
||||
self.tool_call_chunk += response_item # type: ignore
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
or self.tool_kickoff is None
|
||||
or self.tool_final_result is None
|
||||
):
|
||||
return None
|
||||
|
||||
tool_runner = self.tool_runner
|
||||
new_prompt_builder = tool_runner.tool.build_next_prompt(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tool_call_summary=self.tool_call_summary,
|
||||
tool_responses=self.tool_responses,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
args=None,
|
||||
),
|
||||
files=current_llm_call.files,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
tool_call_info=[
|
||||
self.tool_kickoff,
|
||||
*self.tool_responses,
|
||||
self.tool_final_result,
|
||||
],
|
||||
)
|
||||
115
backend/danswer/chat/tools.py
Normal file
115
backend/danswer/chat/tools.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
@@ -81,14 +84,7 @@ OAUTH_CLIENT_SECRET = (
|
||||
or ""
|
||||
)
|
||||
|
||||
# for future OAuth connector support
|
||||
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -122,8 +118,6 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
|
||||
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
|
||||
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
# the number of times to try and connect to vespa on startup before giving up
|
||||
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
|
||||
|
||||
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
|
||||
|
||||
@@ -169,17 +163,6 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
# Experimental setting to control idle transactions
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
|
||||
try:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
|
||||
os.environ.get(
|
||||
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
@@ -240,7 +223,7 @@ except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
if not env_value:
|
||||
@@ -268,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
# for some connectors
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
GOOGLE_DRIVE_INCLUDE_SHARED = False
|
||||
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
|
||||
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
|
||||
|
||||
# TODO these should be available for frontend configuration, via advanced options expandable
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
@@ -314,22 +300,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -444,9 +414,6 @@ LOG_ALL_MODEL_INTERACTIONS = (
|
||||
LOG_DANSWER_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
LOG_INDIVIDUAL_MODEL_TOKENS = (
|
||||
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -514,20 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
|
||||
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
|
||||
|
||||
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
@@ -17,6 +17,9 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
# ~3k input, half for docs, half for chat history + prompts
|
||||
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
|
||||
|
||||
# For selecting a different LLM question-answering prompt format
|
||||
# Valid values: default, cot, weak
|
||||
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
@@ -24,6 +27,8 @@ DOC_TIME_DECAY = float(
|
||||
)
|
||||
BASE_RECENCY_DECAY = 0.5
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
# Currently only applies to search flow not chat
|
||||
|
||||
@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -62,6 +60,7 @@ KV_GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
|
||||
KV_SETTINGS_KEY = "danswer_settings"
|
||||
KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
@@ -75,16 +74,12 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
|
||||
@@ -130,8 +125,6 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -214,17 +207,9 @@ class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
# Light queue
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
|
||||
@@ -234,24 +219,11 @@ class DanswerRedisLocks:
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_doc_permissions_sync_beat"
|
||||
)
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
|
||||
PRUNING_LOCK_PREFIX = "da_lock:pruning"
|
||||
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
@@ -261,32 +233,6 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -4,8 +4,11 @@ import os
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
# if set, will default DanswerBot to use quotes and reference documents
|
||||
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
|
||||
@@ -70,9 +70,7 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
)
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
@@ -121,14 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
||||
|
||||
# if specified, will merge the specified JSON with the existing body of the
|
||||
# request before sending it to the LLM
|
||||
LITELLM_EXTRA_BODY: dict | None = None
|
||||
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
|
||||
if _LITELLM_EXTRA_BODY_RAW:
|
||||
try:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -2,8 +2,6 @@ import json
|
||||
import os
|
||||
|
||||
|
||||
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
|
||||
|
||||
# if specified, will pass through request headers to the call to API calls made by custom tools
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
|
||||
|
||||
@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll Connector:
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -5,9 +5,9 @@ from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
from mypy_boto3_s3 import S3Client # type: ignore
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from mypy_boto3_s3 import S3Client
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import BlobType
|
||||
|
||||
@@ -1,26 +1,22 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from danswer.connectors.confluence.utils import attachment_to_content
|
||||
from danswer.connectors.confluence.utils import build_confluence_client
|
||||
from danswer.connectors.confluence.utils import build_confluence_document_id
|
||||
from danswer.connectors.confluence.utils import datetime_from_string
|
||||
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from danswer.connectors.confluence.utils import validate_attachment_filetype
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -54,8 +50,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -72,11 +66,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
@@ -87,62 +80,53 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if cql_query:
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
cql_page_query = cql_query
|
||||
elif space:
|
||||
# if no cql_query is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
elif page_id:
|
||||
# if a cql_query is not provided, we will use the page_id to fetch the page
|
||||
if index_recursively:
|
||||
cql_page_query += f" and ancestor='{page_id}'"
|
||||
else:
|
||||
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
|
||||
cql_page_query += f" and id='{page_id}'"
|
||||
elif space:
|
||||
# if no cql_query or page_id is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
|
||||
self.cql_page_query = cql_page_query
|
||||
self.cql_time_filter = ""
|
||||
|
||||
self.cql_label_filter = ""
|
||||
self.cql_time_filter = ""
|
||||
if labels_to_skip:
|
||||
labels_to_skip = list(set(labels_to_skip))
|
||||
comma_separated_labels = ",".join(
|
||||
f"'{quote(label)}'" for label in labels_to_skip
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
comma_separated_labels = ",".join(labels_to_skip)
|
||||
self.cql_label_filter = f"&label not in ({comma_separated_labels})"
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials=credentials,
|
||||
self.confluence_client = build_confluence_client(
|
||||
credentials_json=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
comment_string = ""
|
||||
|
||||
comment_cql = f"type=comment and container='{page_id}'"
|
||||
comment_cql += self.cql_label_filter
|
||||
|
||||
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
|
||||
for comment in self.confluence_client.paginated_cql_retrieval(
|
||||
for comments in self.confluence_client.paginated_cql_page_retrieval(
|
||||
cql=comment_cql,
|
||||
expand=expand,
|
||||
):
|
||||
comment_string += "\nComment:\n"
|
||||
comment_string += extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=comment,
|
||||
fetched_titles=set(),
|
||||
)
|
||||
for comment in comments:
|
||||
comment_string += "\nComment:\n"
|
||||
comment_string += extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client, confluence_object=comment
|
||||
)
|
||||
|
||||
return comment_string
|
||||
|
||||
@@ -154,28 +138,28 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
"""
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, confluence_object["_links"]["webui"]
|
||||
)
|
||||
|
||||
object_text = None
|
||||
# Extract text from page
|
||||
if confluence_object["type"] == "page":
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=confluence_object,
|
||||
fetched_titles={confluence_object.get("title", "")},
|
||||
self.confluence_client, confluence_object
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
self.confluence_client, confluence_object
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parseable
|
||||
return None
|
||||
|
||||
# Get space name
|
||||
@@ -206,41 +190,44 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
for page in page_batch:
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
for attachments in self.confluence_client.paginated_cql_page_retrieval(
|
||||
cql=attachment_cql,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
for attachment in attachments:
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -250,84 +237,59 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
for pages in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
if not attachment_restrictions:
|
||||
attachment_restrictions = page_restrictions
|
||||
|
||||
attachment_space_key = attachment.get("space", {}).get("key")
|
||||
if not attachment_space_key:
|
||||
attachment_space_key = page_space_key
|
||||
|
||||
attachment_perm_sync_data = {
|
||||
"restrictions": attachment_restrictions or {},
|
||||
"space_key": attachment_space_key,
|
||||
for page in pages:
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
perm_sync_data = {
|
||||
"restrictions": page.get("restrictions", {}),
|
||||
"space_key": page.get("space", {}).get("key"),
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
),
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
yield doc_metadata_list
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
for attachments in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
):
|
||||
for attachment in attachments:
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"]
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
|
||||
@@ -20,10 +20,6 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -84,7 +80,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
@@ -99,10 +95,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
@@ -120,7 +112,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
@@ -134,32 +126,6 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -175,7 +141,7 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
"""
|
||||
@@ -187,43 +153,46 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
while url_suffix:
|
||||
try:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
next_response = self.get(url_suffix)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in confluence call to {url_suffix}")
|
||||
|
||||
# If the problematic expansion is in the url, replace it
|
||||
# with the replacement expansion and try again
|
||||
# If that fails, raise the error
|
||||
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
raise e
|
||||
logger.warning(
|
||||
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
|
||||
" and trying again."
|
||||
)
|
||||
url_suffix = url_suffix.replace(
|
||||
_PROBLEMATIC_EXPANSIONS,
|
||||
_REPLACEMENT_EXPANSIONS,
|
||||
)
|
||||
continue
|
||||
|
||||
# yield the results individually
|
||||
yield from next_response.get("results", [])
|
||||
|
||||
logger.exception("Error in danswer_cql: \n")
|
||||
raise e
|
||||
yield next_response.get("results", [])
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_cql_retrieval(
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
return self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
group_name = quote(group_name)
|
||||
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
The content/search endpoint can be used to fetch pages, attachments, and comments.
|
||||
"""
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
yield from self._paginate_url(
|
||||
return self._paginate_url(
|
||||
f"rest/api/search/user?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def paginated_cql_page_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/content/search?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
@@ -232,7 +201,7 @@ class OnyxConfluence(Confluence):
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This function will paginate through the top level query first, then
|
||||
paginate through all of the expansions.
|
||||
@@ -252,121 +221,6 @@ class OnyxConfluence(Confluence):
|
||||
for item in data:
|
||||
_traverse_and_update(item)
|
||||
|
||||
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(confluence_object)
|
||||
yield confluence_object
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
self,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
The search/user endpoint can be used to fetch users.
|
||||
It's a seperate endpoint from the content/search endpoint used only for users.
|
||||
Otherwise it's very similar to the content/search endpoint.
|
||||
"""
|
||||
cql = "type=user"
|
||||
url = "rest/api/search/user" if self.cloud else "rest/api/search"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
yield from self._paginate_url(url, limit)
|
||||
|
||||
def paginated_groups_by_user_retrieval(
|
||||
self,
|
||||
user: dict[str, Any],
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch groups.
|
||||
"""
|
||||
user_field = "accountId" if self.cloud else "key"
|
||||
user_value = user["accountId"] if self.cloud else user["userKey"]
|
||||
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
|
||||
user_query = f"{user_field}={quote(user_value)}"
|
||||
|
||||
url = f"rest/api/user/memberof?{user_query}"
|
||||
yield from self._paginate_url(url, limit)
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch groups.
|
||||
"""
|
||||
yield from self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch the members of a group.
|
||||
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
|
||||
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
|
||||
"""
|
||||
group_name = quote(group_name)
|
||||
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
|
||||
def build_confluence_client(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(results)
|
||||
yield results
|
||||
|
||||
@@ -2,7 +2,6 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -32,11 +31,7 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
@@ -76,9 +71,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
@@ -86,7 +79,7 @@ def extract_text_from_confluence_html(
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
@@ -107,93 +100,22 @@ def extract_text_from_confluence_html(
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
return attachment["metadata"]["mediaType"] not in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
@@ -231,9 +153,7 @@ def attachment_to_content(
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
@@ -244,12 +164,10 @@ def build_confluence_document_id(
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
if is_cloud and not base_url.endswith("/wiki"):
|
||||
base_url += "/wiki"
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
def extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachments in use
|
||||
|
||||
@@ -277,3 +195,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def build_confluence_client(
|
||||
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
|
||||
) -> OnyxConfluence:
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials_json["confluence_username"] if is_cloud else None,
|
||||
password=credentials_json["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials_json["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=60,
|
||||
max_backoff_seconds=60,
|
||||
)
|
||||
|
||||
@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:
|
||||
|
||||
|
||||
def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
try:
|
||||
dt = parse(datetime_str)
|
||||
except ValueError:
|
||||
# Handle malformed timezone by attempting to fix common format issues
|
||||
if "0000" in datetime_str:
|
||||
# Convert "0000" to "+0000" for proper timezone parsing
|
||||
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
|
||||
dt = parse(fixed_dt_str)
|
||||
else:
|
||||
raise
|
||||
dt = parse(datetime_str)
|
||||
return datetime_to_utc(dt)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
@@ -12,93 +12,129 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
|
||||
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
|
||||
from danswer.connectors.danswer_jira.utils import build_jira_client
|
||||
from danswer.connectors.danswer_jira.utils import build_jira_url
|
||||
from danswer.connectors.danswer_jira.utils import extract_jira_project
|
||||
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
|
||||
from danswer.connectors.danswer_jira.utils import get_comment_strs
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
|
||||
def _paginate_jql_search(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
start = 0
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
parsed_url = urlparse(url)
|
||||
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise Exception(f"Found Jira object not of type Issue: {issue}")
|
||||
# Split the path by '/' and find the position of 'projects' to get the project name
|
||||
split_path = parsed_url.path.split("/")
|
||||
if PROJECT_URL_PAT in split_path:
|
||||
project_pos = split_path.index(PROJECT_URL_PAT)
|
||||
if len(split_path) > project_pos + 1:
|
||||
jira_project = split_path[project_pos + 1]
|
||||
else:
|
||||
raise ValueError("No project name found in the URL")
|
||||
else:
|
||||
raise ValueError("'projects' not found in the URL")
|
||||
|
||||
if len(issues) < max_results:
|
||||
break
|
||||
return jira_base, jira_project
|
||||
|
||||
start += max_results
|
||||
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
|
||||
comment_strs.append(body_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process comment due to an error: {e}")
|
||||
continue
|
||||
|
||||
return comment_strs
|
||||
|
||||
|
||||
def fetch_jira_issues_batch(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
batch_size: int,
|
||||
start_index: int,
|
||||
jira_client: JIRA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> Iterable[Document]:
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=jira_client,
|
||||
jql=jql,
|
||||
max_results=batch_size,
|
||||
):
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch = []
|
||||
|
||||
batch = jira_client.search_issues(
|
||||
jql,
|
||||
startAt=start_index,
|
||||
maxResults=batch_size,
|
||||
)
|
||||
|
||||
for jira in batch:
|
||||
if type(jira) != Issue:
|
||||
logger.warning(f"Found Jira object not of type Issue {jira}")
|
||||
continue
|
||||
|
||||
if labels_to_skip and any(
|
||||
label in jira.fields.labels for label in labels_to_skip
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping {jira.key} because it has a label to skip. Found "
|
||||
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
issue.fields.description
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
@@ -106,53 +142,66 @@ def fetch_jira_issues_batch(
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"Skipping {jira.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.creator.displayName,
|
||||
email=jira.fields.creator.emailAddress,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
people.add(
|
||||
BasicExpertInfo(
|
||||
display_name=jira.fields.assignee.displayName, # type: ignore
|
||||
email=jira.fields.assignee.emailAddress, # type: ignore
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
yield Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=issue.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class JiraConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_project_url: str,
|
||||
@@ -164,8 +213,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
|
||||
self._jira_client: JIRA | None = None
|
||||
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
|
||||
self.jira_client: JIRA | None = None
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
@@ -174,45 +223,54 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
return tuple(email.strip() for email in self._comment_email_blacklist)
|
||||
|
||||
@property
|
||||
def jira_client(self) -> JIRA:
|
||||
if self._jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
return self._jira_client
|
||||
|
||||
@property
|
||||
def quoted_jira_project(self) -> str:
|
||||
# Quote the project name to handle reserved words
|
||||
return f'"{self._jira_project}"'
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
credentials=credentials,
|
||||
jira_base=self.jira_base,
|
||||
)
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
email = credentials["jira_user_email"]
|
||||
self.jira_client = JIRA(
|
||||
basic_auth=(email, api_token),
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
else:
|
||||
self.jira_client = JIRA(
|
||||
token_auth=api_token,
|
||||
server=self.jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=f"project = {quoted_project}",
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
|
||||
yield document_batch
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
@@ -220,54 +278,31 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
# Quote the project name to handle reserved words
|
||||
quoted_project = f'"{self.jira_project}"'
|
||||
jql = (
|
||||
f"project = {self.quoted_jira_project} AND "
|
||||
f"project = {quoted_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
|
||||
yield document_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields="key",
|
||||
):
|
||||
issue_key = best_effort_get_field_from_issue(issue, "key")
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
perm_sync_data=None,
|
||||
)
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
|
||||
jql=jql,
|
||||
start_index=start_ind,
|
||||
jira_client=self.jira_client,
|
||||
batch_size=self.batch_size,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
)
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
|
||||
yield slim_doc_batch
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
start_ind += fetched_batch_size
|
||||
if fetched_batch_size < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,136 +1,17 @@
|
||||
"""Module with custom fields processing functions"""
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
|
||||
|
||||
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
|
||||
display_name = None
|
||||
email = None
|
||||
if hasattr(obj, "display_name"):
|
||||
display_name = obj.display_name
|
||||
else:
|
||||
display_name = obj.get("displayName")
|
||||
|
||||
if hasattr(obj, "emailAddress"):
|
||||
email = obj.emailAddress
|
||||
else:
|
||||
email = obj.get("emailAddress")
|
||||
|
||||
if not email and not display_name:
|
||||
return None
|
||||
|
||||
return BasicExpertInfo(display_name=display_name, email=email)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
|
||||
return f"{jira_client.client_info()}/browse/{issue_key}"
|
||||
|
||||
|
||||
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
email = credentials["jira_user_email"]
|
||||
return JIRA(
|
||||
basic_auth=(email, api_token),
|
||||
server=jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
else:
|
||||
return JIRA(
|
||||
token_auth=api_token,
|
||||
server=jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
)
|
||||
|
||||
|
||||
def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
parsed_url = urlparse(url)
|
||||
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
|
||||
|
||||
# Split the path by '/' and find the position of 'projects' to get the project name
|
||||
split_path = parsed_url.path.split("/")
|
||||
if PROJECT_URL_PAT in split_path:
|
||||
project_pos = split_path.index(PROJECT_URL_PAT)
|
||||
if len(split_path) > project_pos + 1:
|
||||
jira_project = split_path[project_pos + 1]
|
||||
else:
|
||||
raise ValueError("No project name found in the URL")
|
||||
else:
|
||||
raise ValueError("'projects' not found in the URL")
|
||||
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def get_comment_strs(
|
||||
issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
comment_strs = []
|
||||
for comment in issue.fields.comment.comments:
|
||||
try:
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
|
||||
comment_strs.append(body_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process comment due to an error: {e}")
|
||||
continue
|
||||
|
||||
return comment_strs
|
||||
|
||||
|
||||
class CustomFieldExtractor:
|
||||
@staticmethod
|
||||
def _process_custom_field_value(value: Any) -> str:
|
||||
|
||||
@@ -16,8 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.fireflies.connector import FirefliesConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
@@ -101,8 +99,6 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -123,13 +123,9 @@ def _process_file(
|
||||
"filename",
|
||||
"file_display_name",
|
||||
"title",
|
||||
"connector_type",
|
||||
]
|
||||
}
|
||||
|
||||
source_type_str = all_metadata.get("connector_type")
|
||||
source_type = DocumentSource(source_type_str) if source_type_str else None
|
||||
|
||||
p_owner_names = all_metadata.get("primary_owners")
|
||||
s_owner_names = all_metadata.get("secondary_owners")
|
||||
p_owners = (
|
||||
@@ -149,7 +145,7 @@ def _process_file(
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=source_type or DocumentSource.FILE,
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
|
||||
|
||||
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
|
||||
|
||||
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
|
||||
|
||||
_FIREFLIES_API_QUERY = """
|
||||
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
|
||||
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
|
||||
id
|
||||
title
|
||||
host_email
|
||||
participants
|
||||
date
|
||||
transcript_url
|
||||
sentences {
|
||||
text
|
||||
speaker_name
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
meeting_text = ""
|
||||
sentences = transcript.get("sentences", [])
|
||||
if sentences:
|
||||
for sentence in sentences:
|
||||
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
|
||||
meeting_text += ": " + sentence.get("text", "") + "\n\n"
|
||||
else:
|
||||
return None
|
||||
|
||||
meeting_link = transcript["transcript_url"]
|
||||
|
||||
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
|
||||
|
||||
meeting_title = transcript["title"] or "No Title"
|
||||
|
||||
meeting_date_unix = transcript["date"]
|
||||
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
|
||||
|
||||
meeting_host_email = transcript["host_email"]
|
||||
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
|
||||
|
||||
meeting_participants_email_list = []
|
||||
for participant in transcript.get("participants", []):
|
||||
if participant != meeting_host_email and participant:
|
||||
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
|
||||
|
||||
return Document(
|
||||
id=fireflies_id,
|
||||
sections=[
|
||||
Section(
|
||||
link=meeting_link,
|
||||
text=meeting_text,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
metadata={},
|
||||
doc_updated_at=meeting_date,
|
||||
primary_owners=host_email_user_info,
|
||||
secondary_owners=meeting_participants_email_list,
|
||||
)
|
||||
|
||||
|
||||
class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def load_credentials(self, credentials: dict[str, str]) -> None:
|
||||
api_key = credentials.get("fireflies_api_key")
|
||||
|
||||
if not isinstance(api_key, str):
|
||||
raise ConnectorMissingCredentialError(
|
||||
"The Fireflies API key must be a string"
|
||||
)
|
||||
|
||||
self.api_key = api_key
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_transcripts(
|
||||
self, start_datetime: str | None = None, end_datetime: str | None = None
|
||||
) -> Iterator[List[dict]]:
|
||||
if self.api_key is None:
|
||||
raise ConnectorMissingCredentialError("Missing API key")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
}
|
||||
|
||||
skip = 0
|
||||
variables: dict[str, int | str] = {
|
||||
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
|
||||
}
|
||||
|
||||
if start_datetime:
|
||||
variables["fromDate"] = start_datetime
|
||||
if end_datetime:
|
||||
variables["toDate"] = end_datetime
|
||||
|
||||
while True:
|
||||
variables["skip"] = skip
|
||||
response = requests.post(
|
||||
_FIREFLIES_API_URL,
|
||||
headers=headers,
|
||||
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code == 204:
|
||||
break
|
||||
|
||||
recieved_transcripts = response.json()
|
||||
parsed_transcripts = recieved_transcripts.get("data", {}).get(
|
||||
"transcripts", []
|
||||
)
|
||||
|
||||
yield parsed_transcripts
|
||||
|
||||
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
|
||||
break
|
||||
|
||||
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
|
||||
|
||||
def _process_transcripts(
|
||||
self, start: str | None = None, end: str | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for transcript_batch in self._fetch_transcripts(start, end):
|
||||
for transcript in transcript_batch:
|
||||
if doc := _create_doc_from_transcript(transcript):
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._process_transcripts()
|
||||
|
||||
def poll_source(
|
||||
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(
|
||||
start_unixtime, tz=timezone.utc
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
|
||||
yield from self._process_transcripts(start_datetime, end_datetime)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user