mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
1 Commits
cohere_def
...
remove_end
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14f57d6475 |
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
|
||||
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -28,11 +28,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -41,16 +41,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -65,17 +65,17 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -95,42 +95,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
76
.github/workflows/nightly-scan-licenses.yml
vendored
76
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,76 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@v2
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
11
.github/workflows/pr-Integration-tests.yml
vendored
11
.github/workflows/pr-Integration-tests.yml
vendored
@@ -210,18 +210,17 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -24,7 +28,7 @@ jobs:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
@@ -41,31 +45,24 @@ jobs:
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -18,11 +18,6 @@ env:
|
||||
# Jira
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -15,7 +15,7 @@ env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
|
||||
17
README.md
17
README.md
@@ -1,5 +1,4 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#danswer-ai/danswer&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
ARG CA_CERT_CONTENT=""
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
@@ -38,6 +39,15 @@ RUN apt-get update && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
# Conditionally write the CA certificate and update certificates
|
||||
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
|
||||
echo "Adding custom CA certificate"; \
|
||||
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
|
||||
update-ca-certificates; \
|
||||
else \
|
||||
echo "No custom CA certificate provided"; \
|
||||
fi
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
@@ -77,6 +87,7 @@ RUN apt-get update && \
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: 33cb72ea4d80
|
||||
Revises: 5b29123cd710
|
||||
Create Date: 2024-11-01 12:51:01.535003
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33cb72ea4d80"
|
||||
down_revision = "5b29123cd710"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Delete extraneous ToolCall entries
|
||||
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool_call
|
||||
WHERE id NOT IN (
|
||||
SELECT MIN(id)
|
||||
FROM tool_call
|
||||
WHERE message_id IS NOT NULL
|
||||
GROUP BY message_id
|
||||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Add a unique constraint on message_id
|
||||
op.create_unique_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
columns=["message_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Step 1: Drop the unique constraint on message_id
|
||||
op.drop_constraint(
|
||||
constraint_name="uq_tool_call_message_id",
|
||||
table_name="tool_call",
|
||||
type_="unique",
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
"""nullable search settings for historic index attempts
|
||||
|
||||
Revision ID: 5b29123cd710
|
||||
Revises: 949b4a92a401
|
||||
Create Date: 2024-10-30 19:37:59.630704
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5b29123cd710"
|
||||
down_revision = "949b4a92a401"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the existing foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be nullable
|
||||
op.alter_column(
|
||||
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
# Add back the foreign key with ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Warning: This will delete all index attempts that don't have search settings
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE search_settings_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Modify the column to be not nullable
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
"search_settings_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add back the foreign key without ON DELETE SET NULL
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -288,15 +288,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""remove description from starter messages
|
||||
|
||||
Revision ID: b72ed7a5db0e
|
||||
Revises: 33cb72ea4d80
|
||||
Create Date: 2024-11-03 15:55:28.944408
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b72ed7a5db0e"
|
||||
down_revision = "33cb72ea4d80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem - 'description')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET starter_messages = (
|
||||
SELECT jsonb_agg(elem || '{"description": ""}')
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""add recent assistants
|
||||
|
||||
Revision ID: c0fd6e4da83a
|
||||
Revises: b72ed7a5db0e
|
||||
Create Date: 2024-11-03 17:28:54.916618
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0fd6e4da83a"
|
||||
down_revision = "b72ed7a5db0e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
@@ -23,56 +23,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
@@ -75,7 +75,6 @@ from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
@@ -84,27 +83,24 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -194,6 +190,20 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
@@ -228,13 +238,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
)
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -255,7 +271,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
@@ -276,9 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -294,18 +308,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
)
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -356,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -438,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
@@ -468,7 +477,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -500,30 +510,19 @@ cookie_transport = CookieTransport(
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
@@ -625,12 +624,14 @@ async def double_check_user(
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
@@ -639,7 +640,8 @@ async def double_check_user(
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
@@ -665,13 +667,15 @@ async def current_curator_or_admin_user(
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
@@ -683,7 +687,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
@@ -876,22 +881,3 @@ def get_oauth_router(
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
|
||||
@@ -3,7 +3,6 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
@@ -12,22 +11,18 @@ from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -113,27 +108,29 @@ def on_task_postrun(
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(tenant_id, int(document_set_id))
|
||||
rds = RedisDocumentSet(int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(tenant_id, int(usergroup_id))
|
||||
rug = RedisUserGroup(int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDelete.PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
rcd = RedisConnectorDeletion(int(cc_pair_id))
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
rcp = RedisConnectorPruning(int(cc_pair_id))
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
@@ -143,154 +140,27 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness probe starting.")
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Redis: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Redis: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for the db to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Database: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result:
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Database: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Database: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
if response_dict["status"]["code"] == "up":
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Vespa: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
while True:
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Primary worker was not ready within the timeout. "
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
@@ -298,7 +168,57 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for all tenant primary workers to be ready...")
|
||||
time_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
# Check if we have a primary worker lock for each tenant
|
||||
all_tenants_ready = all(
|
||||
get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
for tenant_id in tenant_ids
|
||||
)
|
||||
|
||||
if all_tenants_ready:
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
ready_tenants = sum(
|
||||
1
|
||||
for tenant_id in tenant_ids
|
||||
if get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Not all tenant primary workers are ready yet. "
|
||||
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Not all tenant primary workers were ready within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("All tenant primary workers are ready. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
@@ -310,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
for tenant_id, lock in sender.primary_worker_locks.items():
|
||||
try:
|
||||
if lock and lock.owned():
|
||||
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
|
||||
try:
|
||||
lock.release()
|
||||
logger.debug(f"Successfully released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
sender.primary_worker_locks[tenant_id] = None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -3,162 +3,28 @@ from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._update_tenant_tasks()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
now = self.app.now()
|
||||
if (
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating tenant task update")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Tenant task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting tenant task update process")
|
||||
try:
|
||||
logger.info("Fetching all tenant IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} tenants")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
logger.debug(f"Creating task configuration for {task_name}")
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_beat_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
logger.debug("Comparing current and new schedules")
|
||||
current_tasks = set(name for name, _ in current_schedule)
|
||||
new_tasks = set(new_schedule.keys())
|
||||
needs_update = current_tasks != new_tasks
|
||||
logger.debug(f"Schedule update needed: {needs_update}")
|
||||
return needs_update
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
# celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -169,4 +35,68 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration once
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
@@ -14,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -82,11 +74,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -91,6 +85,5 @@ celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -13,21 +13,21 @@ from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -75,64 +75,95 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
# by the primary worker
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_locks[tenant_id] = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
RedisDocumentSet.reset_all(r)
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisUserGroup.reset_all(r)
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorPrune.reset_all(r)
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorIndex.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
@@ -185,36 +216,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
if not hasattr(worker, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
lock = worker.primary_worker_lock
|
||||
# Retrieve all tenant IDs
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
for tenant_id in tenant_ids:
|
||||
lock = worker.primary_worker_locks.get(tenant_id)
|
||||
if not lock:
|
||||
continue # Skip if no lock for this tenant
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"Full acquisition of primary worker lock. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
task_logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info("Primary worker lock: Acquire succeeded.")
|
||||
worker.primary_worker_lock = lock
|
||||
if lock.owned():
|
||||
task_logger.debug(
|
||||
f"Reacquiring primary worker lock for tenant {tenant_id}."
|
||||
)
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.error("Primary worker lock: Acquire failed!")
|
||||
raise TimeoutError("Primary worker lock could not be acquired!")
|
||||
task_logger.warning(
|
||||
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
|
||||
)
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
|
||||
)
|
||||
worker.primary_worker_locks[tenant_id] = lock
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Periodic task failed.")
|
||||
|
||||
@@ -1,10 +1,568 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: str):
|
||||
self._id: str = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[2]
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> str | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class is used to scan documents by cc_pair in the db and collect them into
|
||||
a unified set for syncing.
|
||||
|
||||
It differs from the other redis helpers in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorPruning(RedisObjectHelper):
|
||||
"""Celery will kick off a long running generator task to crawl the connector and
|
||||
find any missing docs, which will each then get a new cleanup task. The progress of
|
||||
those tasks will then be monitored to completion.
|
||||
|
||||
Example rough happy path order:
|
||||
Check connectorpruning_fence_1
|
||||
Send generator task with id connectorpruning+generator_1_{uuid}
|
||||
|
||||
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
|
||||
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
|
||||
in taskset connectorpruning_taskset_1
|
||||
on completion, generator sets connectorpruning_generator_complete_1
|
||||
|
||||
celery postrun removes subtasks from taskset
|
||||
monitor beat task cleans up when taskset reaches 0 items
|
||||
"""
|
||||
|
||||
PREFIX = "connectorpruning"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
self.documents_to_prune: set[str] = set()
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
for doc_id in self.documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorIndexing(RedisObjectHelper):
|
||||
"""Celery will kick off a long running indexing task to crawl the connector and
|
||||
find any new or updated docs docs, which will each then get a new sync task or be
|
||||
indexed inline.
|
||||
|
||||
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
|
||||
e.g. "2/5"
|
||||
"""
|
||||
|
||||
PREFIX = "connectorindexing"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
|
||||
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
SUBTASK_PREFIX = PREFIX + "+sub"
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # a signal that contains generator progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # a signal that the generator has finished
|
||||
|
||||
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
|
||||
super().__init__(f"{cc_pair_id}/{search_settings_id}")
|
||||
|
||||
@property
|
||||
def generator_lock_key(self) -> str:
|
||||
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_task_id_prefix(self) -> str:
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_progress_key(self) -> str:
|
||||
# example: connectorpruning_generator_progress_1
|
||||
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def generator_complete_key(self) -> str:
|
||||
# example: connectorpruning_generator_complete_1
|
||||
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def subtask_id_prefix(self) -> str:
|
||||
return f"{self.SUBTASK_PREFIX}_{self._id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
def is_indexing(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorStop(RedisObjectHelper):
|
||||
"""Used to signal any running tasks for a connector to stop. We should refactor
|
||||
connector related redis helpers into a single class.
|
||||
"""
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
@@ -17,7 +18,7 @@ from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -40,14 +41,14 @@ def _get_deletion_status(
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
@@ -10,6 +10,13 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -18,8 +25,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@@ -57,7 +62,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -66,10 +71,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
redis_connector.stop.set_fence(True)
|
||||
r.set(rcs.fence_key, cc_pair_id)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
r.delete(rcs.fence_key)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
@@ -101,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if redis_connector.delete.fenced:
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to load the state of the object inside the fence
|
||||
@@ -118,49 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
fence_payload = RedisConnectorDeletionFenceData(
|
||||
fence_value = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
|
||||
if r.get(rci.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if r.get(rcp.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
redis_connector.delete.taskset_clear()
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.delete.generate_tasks(
|
||||
app, db_session, lock_beat
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.fence_key)
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
@@ -175,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
fence_value.num_tasks = tasks_generated
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -2,9 +2,10 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@@ -13,6 +14,12 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
@@ -43,15 +50,12 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -101,22 +105,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
if old_search_settings:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
# only warm up if search settings were changed
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -125,7 +126,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
@@ -138,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
rci = RedisConnectorIndexing(
|
||||
cc_pair_id, search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -175,9 +175,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Indexing queued: index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -306,15 +304,15 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
if r.exists(rci.fence_key):
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -322,17 +320,19 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.taskset_key)
|
||||
|
||||
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexingFenceData(
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_fence(payload)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
@@ -345,8 +345,6 @@ def try_creating_indexing_task(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
@@ -363,12 +361,11 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
fence_value.index_attempt_id = index_attempt_id
|
||||
fence_value.celery_task_id = result.id
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
redis_connector_index.set_fence(payload)
|
||||
r.delete(rci.fence_key)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -391,12 +388,7 @@ def connector_indexing_proxy_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing proxy - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
@@ -410,56 +402,29 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
f"Indexing proxy - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
break
|
||||
job.release()
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -481,97 +446,78 @@ def connector_indexing_task(
|
||||
|
||||
Returns None if the task did not run (possibly due to a conflict).
|
||||
Otherwise, returns an int >= 0 representing the number of indexed docs.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
attempt = None
|
||||
n_final_progress = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
if r.exists(rcd.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
f"fence={rcd.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
if r.exists(rcs.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
f"fence={rcs.fence_key}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
while True:
|
||||
# wait for the fence to come up
|
||||
if not redis_connector_index.fenced:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
rci.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
fence_data.started = datetime.now(timezone.utc)
|
||||
r.set(rci.fence_key, fence_data.model_dump_json())
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -580,7 +526,6 @@ def connector_indexing_task(
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -600,52 +545,43 @@ def connector_indexing_task(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rci.generator_progress_key, lock, r
|
||||
)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
# get back the total number of indexed docs and return it
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
n_final_progress = int(cast(int, generator_progress_value))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
|
||||
if attempt:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return n_final_progress
|
||||
|
||||
@@ -11,6 +11,9 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
@@ -30,7 +33,6 @@ from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import pruning_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -145,11 +147,8 @@ def try_creating_prune_generator_task(
|
||||
is used to trigger prunes immediately, e.g. via the web ui.
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
count = redis_connector.prune.get_active_task_count()
|
||||
if count > 0:
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
@@ -166,10 +165,15 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
if redis_connector.prune.fenced: # skip pruning if already pruning
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
|
||||
# skip pruning if already pruning
|
||||
if r.exists(rcp.fence_key):
|
||||
return None
|
||||
|
||||
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
|
||||
# skip pruning if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
@@ -177,10 +181,10 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
@@ -196,7 +200,7 @@ def try_creating_prune_generator_task(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
r.set(rcp.fence_key, 1)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
@@ -231,12 +235,12 @@ def connector_pruning_generator_task(
|
||||
pruning_ctx_dict["request_id"] = self.request.id
|
||||
pruning_ctx.set(pruning_ctx_dict)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -269,11 +273,10 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
rcs.fence_key, rcp.generator_progress_key, lock, r
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
@@ -300,29 +303,31 @@ def connector_pruning_generator_task(
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
self.app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.fence_key)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorDeletionFenceData(BaseModel):
|
||||
num_tasks: int | None
|
||||
submitted: datetime
|
||||
@@ -0,0 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
@@ -19,6 +19,18 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -55,14 +67,7 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
@@ -187,7 +192,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -223,10 +228,10 @@ def try_generate_document_set_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if rds.fenced:
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
@@ -264,7 +269,7 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -278,9 +283,10 @@ def try_generate_user_group_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(tenant_id, usergroup_id)
|
||||
if rug.fenced:
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
@@ -320,7 +326,7 @@ def try_generate_user_group_sync_tasks(
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -346,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -356,12 +362,16 @@ def monitor_document_set_taskset(
|
||||
|
||||
document_set_id = int(document_set_id_str)
|
||||
|
||||
rds = RedisDocumentSet(tenant_id, document_set_id)
|
||||
if not rds.fenced:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial_count = rds.payload
|
||||
if initial_count is None:
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
@@ -389,38 +399,48 @@ def monitor_document_set_taskset(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
key_bytes: bytes, r: Redis, tenant_id: str | None
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rcd.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
return
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
# the fence is setting up but isn't ready yet
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -504,15 +524,15 @@ def monitor_connector_deletion_taskset(
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.taskset_clear()
|
||||
redis_connector.delete.set_fence(None)
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
@@ -521,37 +541,46 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcp.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
generator_value = r.get(rcp.generator_complete_key)
|
||||
if generator_value is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
try:
|
||||
initial_count = int(cast(int, generator_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcp.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if remaining > 0:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
r.delete(rcp.taskset_key)
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.generator_complete_key)
|
||||
r.delete(rcp.fence_key)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
@@ -566,37 +595,53 @@ def monitor_ccpair_indexing_taskset(
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
|
||||
|
||||
generator_progress_value = r.get(rci.generator_progress_key)
|
||||
if generator_progress_value is not None:
|
||||
try:
|
||||
progress_count = int(cast(int, generator_progress_value))
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"progress={progress_count} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
|
||||
)
|
||||
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None:
|
||||
generator_complete_value = r.get(rci.generator_complete_key)
|
||||
if generator_complete_value is None:
|
||||
if result_state in READY_STATES:
|
||||
# IF the task state is READY, THEN generator_complete should be set
|
||||
# if it isn't, then the worker crashed
|
||||
@@ -607,18 +652,30 @@ def monitor_ccpair_indexing_taskset(
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
index_attempt=index_attempt,
|
||||
db_session=db_session,
|
||||
failure_reason="Connector indexing aborted or exceptioned.",
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
try:
|
||||
status_value = int(cast(int, generator_complete_value))
|
||||
status_enum = HTTPStatus(status_value)
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"monitor_ccpair_indexing_taskset: "
|
||||
f"generator_complete_value=f{generator_complete_value} could not be parsed."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
@@ -627,7 +684,11 @@ def monitor_ccpair_indexing_taskset(
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.taskset_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
r.delete(rci.generator_complete_key)
|
||||
r.delete(rci.fence_key)
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
@@ -639,7 +700,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
Returns True if the task actually did work, False
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -690,33 +751,27 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
rci = RedisConnectorIndexing(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
if not r.exists(fence_key):
|
||||
failure_reason = (
|
||||
f"Unknown index attempt. Might be left over from a process restart: "
|
||||
f"index_attempt={a.id} "
|
||||
f"cc_pair={a.connector_credential_pair_id} "
|
||||
f"search_settings={a.search_settings_id}"
|
||||
)
|
||||
task_logger.warning(failure_reason)
|
||||
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
|
||||
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
|
||||
if not r.exists(rci.fence_key):
|
||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
@@ -727,19 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.beat", "celery_app"
|
||||
)
|
||||
|
||||
@@ -29,26 +29,18 @@ JobStatusType = (
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on SQLAlchemy's recommendations to handle multiprocessing:
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
# Optionally set a custom app name for database logging purposes
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -118,13 +118,7 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
@@ -337,7 +331,7 @@ def _run_indexing(
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
@@ -372,7 +366,7 @@ def _run_indexing(
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
|
||||
@@ -42,14 +42,18 @@ personas:
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Give me an overview of what's here"
|
||||
message: "Sample some documents and tell me what you find."
|
||||
- name: "Use AI to solve a work related problem"
|
||||
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
|
||||
- name: "Find updates on a topic of interest"
|
||||
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
|
||||
- name: "Surface contradictions"
|
||||
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
|
||||
- name: "General Information"
|
||||
description: "Ask about available information"
|
||||
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
|
||||
- name: "Specific Topic Search"
|
||||
description: "Search for specific information"
|
||||
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
|
||||
- name: "Recent Updates"
|
||||
description: "Inquire about latest additions"
|
||||
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
|
||||
- name: "Cross-referencing Information"
|
||||
description: "Connect information from different sources"
|
||||
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
@@ -67,14 +71,18 @@ personas:
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Summarize a document"
|
||||
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
|
||||
- name: "Help me with coding"
|
||||
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
|
||||
- name: "Draft a professional email"
|
||||
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
|
||||
- name: "Learn something new"
|
||||
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
|
||||
- name: "Open Discussion"
|
||||
description: "Start an open-ended conversation"
|
||||
message: "Hi! Can you help me write a professional email?"
|
||||
- name: "Problem Solving"
|
||||
description: "Get help with a challenge"
|
||||
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
|
||||
- name: "Learn Something New"
|
||||
description: "Explore a new topic"
|
||||
message: "Hi! Could you explain what project management is in simple terms?"
|
||||
- name: "Creative Brainstorming"
|
||||
description: "Generate creative ideas"
|
||||
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -93,12 +101,16 @@ personas:
|
||||
is_visible: false
|
||||
starter_messages:
|
||||
- name: "Document Search"
|
||||
description: "Find exact information"
|
||||
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
|
||||
- name: "Process Verification"
|
||||
description: "Find exact quotes"
|
||||
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
|
||||
- name: "Technical Documentation"
|
||||
description: "Search technical details"
|
||||
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
|
||||
- name: "Policy Reference"
|
||||
description: "Check official policies"
|
||||
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
|
||||
|
||||
- id: 3
|
||||
@@ -118,11 +130,15 @@ personas:
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Create visuals for a presentation"
|
||||
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
|
||||
- name: "Find inspiration for a marketing campaign"
|
||||
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
|
||||
- name: "Visualize a product design"
|
||||
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
|
||||
- name: "Generate a humorous image response"
|
||||
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"
|
||||
- name: "Landscape"
|
||||
description: "Generate a landscape image"
|
||||
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
|
||||
- name: "Character"
|
||||
description: "Generate a character image"
|
||||
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
|
||||
- name: "Abstract"
|
||||
description: "Create an abstract image"
|
||||
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
|
||||
- name: "Urban Scene"
|
||||
description: "Generate an urban landscape"
|
||||
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."
|
||||
|
||||
@@ -11,18 +11,23 @@ from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -35,6 +40,7 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -54,13 +60,14 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@@ -69,48 +76,36 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -261,11 +256,10 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -281,6 +275,7 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -292,9 +287,6 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -416,20 +408,12 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -500,19 +484,13 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -527,13 +505,7 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -545,7 +517,6 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@@ -561,53 +532,148 @@ def stream_chat_message_objects(
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=answer_style_config,
|
||||
answer_style_config=AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
@@ -675,6 +741,7 @@ def stream_chat_message_objects(
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
@@ -692,7 +759,7 @@ def stream_chat_message_objects(
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
@@ -706,32 +773,11 @@ def stream_chat_message_objects(
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
@@ -761,7 +807,6 @@ def stream_chat_message_objects(
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
logger.debug("Post-LLM answer processing")
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
@@ -777,6 +822,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -784,21 +830,21 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else None
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
@@ -822,6 +868,7 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -831,6 +878,7 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
|
||||
@@ -9,19 +9,19 @@ prompts:
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
|
||||
If there are no relevant documents, refer to the chat history and your internal knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
|
||||
@@ -30,21 +30,21 @@ prompts:
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "ImageGeneration"
|
||||
description: "Generates images from user descriptions!"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
|
||||
|
||||
For appropriate requests, you will generate an image that matches the user's requirements.
|
||||
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
|
||||
|
||||
You aim to be helpful while maintaining appropriate content standards.
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Based on the user's description, create a high-quality image that accurately reflects their request.
|
||||
Pay close attention to the specified details, styles, and desired elements.
|
||||
|
||||
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
@@ -64,13 +64,14 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
@@ -83,6 +84,7 @@ prompts:
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
@@ -90,10 +92,10 @@ prompts:
|
||||
The current date is DANSWER_DATETIME_REPLACEMENT.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
|
||||
@@ -163,17 +163,6 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
# Experimental setting to control idle transactions
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
|
||||
try:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
|
||||
os.environ.get(
|
||||
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
@@ -262,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
# for some connectors
|
||||
ENABLE_EXPENSIVE_EXPERT_CALLS = False
|
||||
|
||||
GOOGLE_DRIVE_INCLUDE_SHARED = False
|
||||
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
|
||||
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
|
||||
|
||||
# TODO these should be available for frontend configuration, via advanced options expandable
|
||||
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
|
||||
@@ -489,17 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
|
||||
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
@@ -125,8 +125,6 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -226,9 +224,6 @@ class DanswerRedisLocks:
|
||||
PRUNING_LOCK_PREFIX = "da_lock:pruning"
|
||||
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -250,11 +249,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
|
||||
@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:
|
||||
|
||||
|
||||
def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
try:
|
||||
dt = parse(datetime_str)
|
||||
except ValueError:
|
||||
# Handle malformed timezone by attempting to fix common format issues
|
||||
if "0000" in datetime_str:
|
||||
# Convert "0000" to "+0000" for proper timezone parsing
|
||||
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
|
||||
dt = parse(fixed_dt_str)
|
||||
else:
|
||||
raise
|
||||
dt = parse(datetime_str)
|
||||
return datetime_to_utc(dt)
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.fireflies.connector import FirefliesConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
@@ -101,8 +99,6 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -123,13 +123,9 @@ def _process_file(
|
||||
"filename",
|
||||
"file_display_name",
|
||||
"title",
|
||||
"connector_type",
|
||||
]
|
||||
}
|
||||
|
||||
source_type_str = all_metadata.get("connector_type")
|
||||
source_type = DocumentSource(source_type_str) if source_type_str else None
|
||||
|
||||
p_owner_names = all_metadata.get("primary_owners")
|
||||
s_owner_names = all_metadata.get("secondary_owners")
|
||||
p_owners = (
|
||||
@@ -149,7 +145,7 @@ def _process_file(
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=source_type or DocumentSource.FILE,
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
|
||||
|
||||
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
|
||||
|
||||
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
|
||||
|
||||
_FIREFLIES_API_QUERY = """
|
||||
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
|
||||
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
|
||||
id
|
||||
title
|
||||
host_email
|
||||
participants
|
||||
date
|
||||
transcript_url
|
||||
sentences {
|
||||
text
|
||||
speaker_name
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
meeting_text = ""
|
||||
sentences = transcript.get("sentences", [])
|
||||
if sentences:
|
||||
for sentence in sentences:
|
||||
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
|
||||
meeting_text += ": " + sentence.get("text", "") + "\n\n"
|
||||
else:
|
||||
return None
|
||||
|
||||
meeting_link = transcript["transcript_url"]
|
||||
|
||||
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
|
||||
|
||||
meeting_title = transcript["title"] or "No Title"
|
||||
|
||||
meeting_date_unix = transcript["date"]
|
||||
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
|
||||
|
||||
meeting_host_email = transcript["host_email"]
|
||||
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
|
||||
|
||||
meeting_participants_email_list = []
|
||||
for participant in transcript.get("participants", []):
|
||||
if participant != meeting_host_email and participant:
|
||||
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
|
||||
|
||||
return Document(
|
||||
id=fireflies_id,
|
||||
sections=[
|
||||
Section(
|
||||
link=meeting_link,
|
||||
text=meeting_text,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
metadata={},
|
||||
doc_updated_at=meeting_date,
|
||||
primary_owners=host_email_user_info,
|
||||
secondary_owners=meeting_participants_email_list,
|
||||
)
|
||||
|
||||
|
||||
class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def load_credentials(self, credentials: dict[str, str]) -> None:
|
||||
api_key = credentials.get("fireflies_api_key")
|
||||
|
||||
if not isinstance(api_key, str):
|
||||
raise ConnectorMissingCredentialError(
|
||||
"The Fireflies API key must be a string"
|
||||
)
|
||||
|
||||
self.api_key = api_key
|
||||
|
||||
return None
|
||||
|
||||
def _fetch_transcripts(
|
||||
self, start_datetime: str | None = None, end_datetime: str | None = None
|
||||
) -> Iterator[List[dict]]:
|
||||
if self.api_key is None:
|
||||
raise ConnectorMissingCredentialError("Missing API key")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
}
|
||||
|
||||
skip = 0
|
||||
variables: dict[str, int | str] = {
|
||||
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
|
||||
}
|
||||
|
||||
if start_datetime:
|
||||
variables["fromDate"] = start_datetime
|
||||
if end_datetime:
|
||||
variables["toDate"] = end_datetime
|
||||
|
||||
while True:
|
||||
variables["skip"] = skip
|
||||
response = requests.post(
|
||||
_FIREFLIES_API_URL,
|
||||
headers=headers,
|
||||
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code == 204:
|
||||
break
|
||||
|
||||
recieved_transcripts = response.json()
|
||||
parsed_transcripts = recieved_transcripts.get("data", {}).get(
|
||||
"transcripts", []
|
||||
)
|
||||
|
||||
yield parsed_transcripts
|
||||
|
||||
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
|
||||
break
|
||||
|
||||
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
|
||||
|
||||
def _process_transcripts(
|
||||
self, start: str | None = None, end: str | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for transcript_batch in self._fetch_transcripts(start, end):
|
||||
for transcript in transcript_batch:
|
||||
if doc := _create_doc_from_transcript(transcript):
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._process_transcripts()
|
||||
|
||||
def poll_source(
|
||||
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(
|
||||
start_unixtime, tz=timezone.utc
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
|
||||
yield from self._process_transcripts(start_datetime, end_datetime)
|
||||
@@ -1,239 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
|
||||
|
||||
|
||||
_TICKET_FIELDS_TO_INCLUDE = {
|
||||
"fr_escalated",
|
||||
"spam",
|
||||
"priority",
|
||||
"source",
|
||||
"status",
|
||||
"type",
|
||||
"is_escalated",
|
||||
"tags",
|
||||
"nr_due_by",
|
||||
"nr_escalated",
|
||||
"cc_emails",
|
||||
"fwd_emails",
|
||||
"reply_cc_emails",
|
||||
"ticket_cc_emails",
|
||||
"support_email",
|
||||
"to_emails",
|
||||
}
|
||||
|
||||
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
1: "Email",
|
||||
2: "Portal",
|
||||
3: "Phone",
|
||||
7: "Chat",
|
||||
9: "Feedback Widget",
|
||||
10: "Outbound Email",
|
||||
}
|
||||
|
||||
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
1: "low",
|
||||
2: "medium",
|
||||
3: "high",
|
||||
4: "urgent",
|
||||
}
|
||||
|
||||
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
2: "open",
|
||||
3: "pending",
|
||||
4: "resolved",
|
||||
5: "closed",
|
||||
}
|
||||
|
||||
|
||||
def _create_metadata_from_ticket(ticket: dict) -> dict:
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
# Combine all emails into a list so there are no repeated emails
|
||||
email_data: set[str] = set()
|
||||
|
||||
for key, value in ticket.items():
|
||||
# Skip fields that aren't useful for embedding
|
||||
if key not in _TICKET_FIELDS_TO_INCLUDE:
|
||||
continue
|
||||
|
||||
# Skip empty fields
|
||||
if not value or value == "[]":
|
||||
continue
|
||||
|
||||
# Convert strings or lists to strings
|
||||
stringified_value: str | list[str]
|
||||
if isinstance(value, list):
|
||||
stringified_value = [str(item) for item in value]
|
||||
else:
|
||||
stringified_value = str(value)
|
||||
|
||||
if "email" in key:
|
||||
if isinstance(stringified_value, list):
|
||||
email_data.update(stringified_value)
|
||||
else:
|
||||
email_data.add(stringified_value)
|
||||
else:
|
||||
metadata[key] = stringified_value
|
||||
|
||||
if email_data:
|
||||
metadata["emails"] = list(email_data)
|
||||
|
||||
# Convert source numbers to human-parsable string
|
||||
if source_number := ticket.get("source"):
|
||||
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
|
||||
source_number, "Unknown Source Type"
|
||||
)
|
||||
|
||||
# Convert priority numbers to human-parsable string
|
||||
if priority_number := ticket.get("priority"):
|
||||
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
|
||||
priority_number, "Unknown Priority"
|
||||
)
|
||||
|
||||
# Convert status to human-parsable string
|
||||
if status_number := ticket.get("status"):
|
||||
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
|
||||
status_number, "Unknown Status"
|
||||
)
|
||||
|
||||
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
|
||||
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
|
||||
# Use the ticket description as the text
|
||||
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
|
||||
metadata = _create_metadata_from_ticket(ticket)
|
||||
|
||||
# This is also used in the ID because it is more unique than the just the ticket ID
|
||||
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
|
||||
|
||||
return Document(
|
||||
id=_FRESHDESK_ID_PREFIX + link,
|
||||
sections=[
|
||||
Section(
|
||||
link=link,
|
||||
text=text,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.FRESHDESK,
|
||||
semantic_identifier=ticket["subject"],
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
ticket["updated_at"].replace("Z", "+00:00")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def load_credentials(self, credentials: dict[str, str | int]) -> None:
|
||||
api_key = credentials.get("freshdesk_api_key")
|
||||
domain = credentials.get("freshdesk_domain")
|
||||
password = credentials.get("freshdesk_password")
|
||||
|
||||
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
|
||||
raise ConnectorMissingCredentialError(
|
||||
"All Freshdesk credentials must be strings"
|
||||
)
|
||||
|
||||
self.api_key = str(api_key)
|
||||
self.domain = str(domain)
|
||||
self.password = str(password)
|
||||
|
||||
def _fetch_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> Iterator[List[dict]]:
|
||||
"""
|
||||
'end' is not currently used, so we may double fetch tickets created after the indexing
|
||||
starts but before the actual call is made.
|
||||
|
||||
To use 'end' would require us to use the search endpoint but it has limitations,
|
||||
namely having to fetch all IDs and then individually fetch each ticket because there is no
|
||||
'include' field available for this endpoint:
|
||||
https://developers.freshdesk.com/api/#filter_tickets
|
||||
"""
|
||||
if self.api_key is None or self.domain is None or self.password is None:
|
||||
raise ConnectorMissingCredentialError("freshdesk")
|
||||
|
||||
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
|
||||
params: dict[str, int | str] = {
|
||||
"include": "description",
|
||||
"per_page": 50,
|
||||
"page": 1,
|
||||
}
|
||||
|
||||
if start:
|
||||
params["updated_since"] = start.isoformat()
|
||||
|
||||
while True:
|
||||
response = requests.get(
|
||||
base_url, auth=(self.api_key, self.password), params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code == 204:
|
||||
break
|
||||
|
||||
tickets = json.loads(response.content)
|
||||
logger.info(
|
||||
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
|
||||
)
|
||||
|
||||
yield tickets
|
||||
|
||||
if len(tickets) < int(params["per_page"]):
|
||||
break
|
||||
|
||||
params["page"] = int(params["page"]) + 1
|
||||
|
||||
def _process_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket_batch in self._fetch_tickets(start, end):
|
||||
for ticket in ticket_batch:
|
||||
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._process_tickets()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
yield from self._process_tickets(start_datetime, end_datetime)
|
||||
@@ -1,360 +1,221 @@
|
||||
from base64 import urlsafe_b64decode
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.google_utils.google_auth import get_google_creds
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_utils.resources import get_admin_service
|
||||
from danswer.connectors.google_utils.resources import get_gmail_service
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
from danswer.connectors.gmail.connector_auth import (
|
||||
get_gmail_creds_for_authorized_user,
|
||||
)
|
||||
from danswer.connectors.gmail.connector_auth import (
|
||||
get_gmail_creds_for_service_account,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
|
||||
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
|
||||
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
|
||||
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# This is for the initial list call to get the thread ids
|
||||
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
|
||||
|
||||
# These are the fields to retrieve using the ID from the initial list call
|
||||
PARTS_FIELDS = "parts(body(data), mimeType)"
|
||||
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
|
||||
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
|
||||
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
|
||||
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
|
||||
|
||||
EMAIL_FIELDS = [
|
||||
"cc",
|
||||
"bcc",
|
||||
"from",
|
||||
"to",
|
||||
]
|
||||
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _build_time_range_query(
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str | None:
|
||||
query = ""
|
||||
if time_range_start is not None and time_range_start != 0:
|
||||
query += f"after:{int(time_range_start)}"
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
|
||||
if len(query) == 0:
|
||||
return None
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def _clean_email_and_extract_name(email: str) -> tuple[str, str | None]:
|
||||
email = email.strip()
|
||||
if "<" in email and ">" in email:
|
||||
# Handle format: "Display Name <email@domain.com>"
|
||||
display_name = email[: email.find("<")].strip()
|
||||
email_address = email[email.find("<") + 1 : email.find(">")].strip()
|
||||
return email_address, display_name if display_name else None
|
||||
else:
|
||||
# Handle plain email address
|
||||
return email.strip(), None
|
||||
|
||||
|
||||
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
|
||||
owners = []
|
||||
for email, names in emails.items():
|
||||
if names:
|
||||
name_parts = names.split(" ")
|
||||
first_name = " ".join(name_parts[:-1])
|
||||
last_name = name_parts[-1]
|
||||
else:
|
||||
first_name = None
|
||||
last_name = None
|
||||
owners.append(
|
||||
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
|
||||
)
|
||||
return owners
|
||||
|
||||
|
||||
def _get_message_body(payload: dict[str, Any]) -> str:
|
||||
parts = payload.get("parts", [])
|
||||
message_body = ""
|
||||
for part in parts:
|
||||
mime_type = part.get("mimeType")
|
||||
body = part.get("body")
|
||||
if mime_type == "text/plain" and body:
|
||||
data = body.get("data", "")
|
||||
text = urlsafe_b64decode(data).decode()
|
||||
message_body += text
|
||||
return message_body
|
||||
|
||||
|
||||
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
|
||||
|
||||
payload = message.get("payload", {})
|
||||
headers = payload.get("headers", [])
|
||||
metadata: dict[str, Any] = {}
|
||||
for header in headers:
|
||||
name = header.get("name").lower()
|
||||
value = header.get("value")
|
||||
if name in EMAIL_FIELDS:
|
||||
metadata[name] = value
|
||||
if name == "subject":
|
||||
metadata["subject"] = value
|
||||
if name == "date":
|
||||
metadata["updated_at"] = value
|
||||
|
||||
if labels := message.get("labelIds"):
|
||||
metadata["labels"] = labels
|
||||
|
||||
message_data = ""
|
||||
for name, value in metadata.items():
|
||||
# updated at isnt super useful for the llm
|
||||
if name != "updated_at":
|
||||
message_data += f"{name}: {value}\n"
|
||||
|
||||
message_body_text: str = _get_message_body(payload)
|
||||
|
||||
return Section(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
|
||||
all_messages = full_thread.get("messages", [])
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
sections = []
|
||||
semantic_identifier = ""
|
||||
updated_at = None
|
||||
from_emails: dict[str, str | None] = {}
|
||||
other_emails: dict[str, str | None] = {}
|
||||
for message in all_messages:
|
||||
section, message_metadata = message_to_section(message)
|
||||
sections.append(section)
|
||||
|
||||
for name, value in message_metadata.items():
|
||||
if name in EMAIL_FIELDS:
|
||||
email, display_name = _clean_email_and_extract_name(value)
|
||||
if name == "from":
|
||||
from_emails[email] = (
|
||||
display_name if not from_emails.get(email) else None
|
||||
)
|
||||
else:
|
||||
other_emails[email] = (
|
||||
display_name if not other_emails.get(email) else None
|
||||
)
|
||||
|
||||
# If we haven't set the semantic identifier yet, set it to the subject of the first message
|
||||
if not semantic_identifier:
|
||||
semantic_identifier = message_metadata.get("subject", "")
|
||||
|
||||
if message_metadata.get("updated_at"):
|
||||
updated_at = message_metadata.get("updated_at")
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
|
||||
id = full_thread.get("id")
|
||||
if not id:
|
||||
raise ValueError("Thread ID is required")
|
||||
|
||||
primary_owners = _get_owners_from_emails(from_emails)
|
||||
secondary_owners = _get_owners_from_emails(other_emails)
|
||||
|
||||
return Document(
|
||||
id=id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=sections,
|
||||
source=DocumentSource.GMAIL,
|
||||
# This is used to perform permission sync
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
doc_updated_at=updated_at_datetime,
|
||||
# Not adding emails to metadata because it's already in the sections
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class GmailConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
self._primary_admin_email: str | None = None
|
||||
|
||||
@property
|
||||
def primary_admin_email(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._creds
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorugh
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
)
|
||||
creds = get_gmail_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str
|
||||
)
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GMAIL,
|
||||
)
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = creds.to_json() if creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
if GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
creds = get_gmail_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Gmail - unknown credential structure."
|
||||
)
|
||||
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
def _get_email_body(self, payload: dict[str, Any]) -> str:
|
||||
parts = payload.get("parts", [])
|
||||
email_body = ""
|
||||
for part in parts:
|
||||
mime_type = part.get("mimeType")
|
||||
body = part.get("body")
|
||||
if mime_type == "text/plain":
|
||||
data = body.get("data", "")
|
||||
text = urlsafe_b64decode(data).decode()
|
||||
email_body += text
|
||||
return email_body
|
||||
|
||||
def _fetch_threads(
|
||||
def _email_to_document(self, full_email: Dict[str, Any]) -> Document:
|
||||
email_id = full_email["id"]
|
||||
payload = full_email["payload"]
|
||||
headers = payload.get("headers")
|
||||
labels = full_email.get("labelIds", [])
|
||||
metadata = {}
|
||||
if headers:
|
||||
for header in headers:
|
||||
name = header.get("name").lower()
|
||||
value = header.get("value")
|
||||
if name in ["from", "to", "subject", "date", "cc", "bcc"]:
|
||||
metadata[name] = value
|
||||
email_data = ""
|
||||
for name, value in metadata.items():
|
||||
email_data += f"{name}: {value}\n"
|
||||
metadata["labels"] = labels
|
||||
logger.debug(f"{email_data}")
|
||||
email_body_text: str = self._get_email_body(payload)
|
||||
date_str = metadata.get("date")
|
||||
email_updated_at = time_str_to_utc(date_str) if date_str else None
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{email_id}"
|
||||
return Document(
|
||||
id=email_id,
|
||||
sections=[Section(link=link, text=email_data + email_body_text)],
|
||||
source=DocumentSource.GMAIL,
|
||||
title=metadata.get("subject"),
|
||||
semantic_identifier=metadata.get("subject", "Untitled Email"),
|
||||
doc_updated_at=email_updated_at,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_time_range_query(
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str | None:
|
||||
query = ""
|
||||
if time_range_start is not None and time_range_start != 0:
|
||||
query += f"after:{int(time_range_start)}"
|
||||
if time_range_end is not None and time_range_end != 0:
|
||||
query += f" before:{int(time_range_end)}"
|
||||
query = query.strip()
|
||||
|
||||
if len(query) == 0:
|
||||
return None
|
||||
|
||||
return query
|
||||
|
||||
def _fetch_mails_from_gmail(
|
||||
self,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
full_threads = execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
if self.creds is None:
|
||||
raise PermissionError("Not logged into Gmail")
|
||||
page_token = ""
|
||||
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
|
||||
service = discovery.build("gmail", "v1", credentials=self.creds)
|
||||
while page_token is not None:
|
||||
result = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(
|
||||
userId="me",
|
||||
pageToken=page_token,
|
||||
q=query,
|
||||
maxResults=self.batch_size,
|
||||
)
|
||||
# full_threads is an iterator containing a single thread
|
||||
# so we need to convert it to a list and grab the first element
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread)
|
||||
if doc is None:
|
||||
continue
|
||||
.execute()
|
||||
)
|
||||
page_token = result.get("nextPageToken")
|
||||
messages = result.get("messages", [])
|
||||
doc_batch = []
|
||||
for message in messages:
|
||||
message_id = message["id"]
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(userId="me", id=message_id, format="full")
|
||||
.execute()
|
||||
)
|
||||
doc = self._email_to_document(msg)
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def _fetch_slim_threads(
|
||||
self,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
doc_batch.append(
|
||||
SlimDocument(
|
||||
id=thread["id"],
|
||||
perm_sync_data={"user_email": user_email},
|
||||
)
|
||||
)
|
||||
if len(doc_batch) > SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
if len(doc_batch) > 0:
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._fetch_threads()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
yield from self._fetch_mails_from_gmail()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._fetch_threads(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._fetch_slim_threads(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
yield from self._fetch_mails_from_gmail(start, end)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
import json
|
||||
import os
|
||||
|
||||
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
|
||||
if not service_account_json_path:
|
||||
raise ValueError(
|
||||
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
|
||||
)
|
||||
with open(service_account_json_path) as f:
|
||||
creds = json.load(f)
|
||||
|
||||
credentials_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(creds),
|
||||
}
|
||||
delegated_user = os.environ.get("GMAIL_DELEGATED_USER")
|
||||
if delegated_user:
|
||||
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
|
||||
|
||||
connector = GmailConnector()
|
||||
connector.load_credentials(
|
||||
json.loads(credentials_dict[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print(document_batch)
|
||||
break
|
||||
|
||||
197
backend/danswer/connectors/gmail/connector_auth.py
Normal file
197
backend/danswer/connectors/gmail/connector_auth.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_frontend_gmail_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
|
||||
|
||||
|
||||
def get_gmail_creds_for_authorized_user(
|
||||
token_json_str: str,
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Gmail tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh gmail access token due to: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_gmail_creds_for_service_account(
|
||||
service_account_key_json_str: str,
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=SCOPES
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
)
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def update_gmail_credential_access_tokens(
|
||||
auth_code: str,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_gmail_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_gmail_service_account_key()
|
||||
|
||||
credential_dict = {
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if delegated_user_email:
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
source=DocumentSource.GMAIL,
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
4
backend/danswer/connectors/gmail/constants.py
Normal file
4
backend/danswer/connectors/gmail/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
|
||||
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
@@ -1,400 +1,556 @@
|
||||
from collections.abc import Callable
|
||||
import io
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_drive.doc_conversion import build_slim_document
|
||||
from danswer.connectors.google_drive.doc_conversion import (
|
||||
convert_drive_item_to_document,
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_utils.google_auth import get_google_creds
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_utils.resources import get_admin_service
|
||||
from danswer.connectors.google_utils.resources import get_drive_service
|
||||
from danswer.connectors.google_utils.resources import get_google_docs_service
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
|
||||
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
|
||||
from danswer.connectors.google_utils.shared_constants import SCOPE_DOC_URL
|
||||
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
|
||||
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
|
||||
# All file retrievals could be batched and made at once
|
||||
|
||||
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
|
||||
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
|
||||
|
||||
|
||||
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
|
||||
if not string:
|
||||
return []
|
||||
return [s.strip() for s in string.split(",") if s.strip()]
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
|
||||
return [url.split("/")[-1] for url in urls]
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _convert_single_file(
|
||||
creds: Any, primary_admin_email: str, file: dict[str, Any]
|
||||
) -> Any:
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
|
||||
user_drive_service = get_drive_service(creds, user_email=user_email)
|
||||
docs_service = get_google_docs_service(creds, user_email=user_email)
|
||||
return convert_drive_item_to_document(
|
||||
file=file,
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
def _run_drive_file_query(
|
||||
service: discovery.Resource,
|
||||
query: str,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
next_page_token = ""
|
||||
while next_page_token is not None:
|
||||
logger.debug(f"Running Google Drive fetch with query: {query}")
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
corpora="allDrives"
|
||||
if include_shared
|
||||
else "user", # needed to search through shared drives
|
||||
pageSize=batch_size,
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
fields=(
|
||||
"nextPageToken, files(mimeType, id, name, permissions, "
|
||||
"modifiedTime, webViewLink, shortcutDetails)"
|
||||
),
|
||||
pageToken=next_page_token,
|
||||
q=query,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
next_page_token = results.get("nextPageToken")
|
||||
files = results["files"]
|
||||
for file in files:
|
||||
if follow_shortcuts and "shortcutDetails" in file:
|
||||
try:
|
||||
file_shortcut_points_to = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file["shortcutDetails"]["targetId"],
|
||||
supportsAllDrives=include_shared,
|
||||
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
yield file_shortcut_points_to
|
||||
except HttpError:
|
||||
logger.error(
|
||||
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
|
||||
)
|
||||
if continue_on_failure:
|
||||
continue
|
||||
raise
|
||||
else:
|
||||
yield file
|
||||
|
||||
|
||||
def _get_folder_id(
|
||||
service: discovery.Resource,
|
||||
parent_id: str,
|
||||
folder_name: str,
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the ID of a folder given its name and the ID of its parent folder.
|
||||
"""
|
||||
query = f"'{parent_id}' in parents and name='{folder_name}' and "
|
||||
if follow_shortcuts:
|
||||
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
|
||||
else:
|
||||
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
|
||||
|
||||
# TODO: support specifying folder path in shared drive rather than just `My Drive`
|
||||
results = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.list(
|
||||
q=query,
|
||||
spaces="drive",
|
||||
fields="nextPageToken, files(id, name, shortcutDetails)",
|
||||
supportsAllDrives=include_shared,
|
||||
includeItemsFromAllDrives=include_shared,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
items = results.get("files", [])
|
||||
|
||||
folder_id = None
|
||||
if items:
|
||||
if follow_shortcuts and "shortcutDetails" in items[0]:
|
||||
folder_id = items[0]["shortcutDetails"]["targetId"]
|
||||
else:
|
||||
folder_id = items[0]["id"]
|
||||
return folder_id
|
||||
|
||||
|
||||
def _get_folders(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
|
||||
if follow_shortcuts:
|
||||
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
|
||||
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
for file in _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
):
|
||||
# Need to check this since file may have been a target of a shortcut
|
||||
# and not necessarily a folder
|
||||
if file["mimeType"] == DRIVE_FOLDER_TYPE:
|
||||
yield file
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def _get_files(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
|
||||
if time_range_start is not None:
|
||||
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
|
||||
query += f"and modifiedTime >= '{time_start}' "
|
||||
if time_range_end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
|
||||
query += f"and modifiedTime <= '{time_stop}' "
|
||||
if folder_id:
|
||||
query += f"and '{folder_id}' in parents "
|
||||
query = query.rstrip() # remove the trailing space(s)
|
||||
|
||||
files = _run_drive_file_query(
|
||||
service=service,
|
||||
query=query,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
|
||||
def _process_files_batch(
|
||||
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
|
||||
for doc in executor.map(convert_func, files):
|
||||
if doc:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return files
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def get_all_files_batched(
|
||||
service: discovery.Resource,
|
||||
continue_on_failure: bool,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
folder_id: str | None = None, # if specified, only fetches files within this folder
|
||||
# if True, will fetch files in sub-folders of the specified folder ID.
|
||||
# Only applies if folder_id is specified.
|
||||
traverse_subfolders: bool = True,
|
||||
folder_ids_traversed: list[str] | None = None,
|
||||
) -> Iterator[list[GoogleDriveFileType]]:
|
||||
"""Gets all files matching the criteria specified by the args from Google Drive
|
||||
in batches of size `batch_size`.
|
||||
"""
|
||||
found_files = _get_files(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=folder_id,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
yield from batch_generator(
|
||||
items=found_files,
|
||||
batch_size=batch_size,
|
||||
pre_batch_yield=lambda batch_files: logger.debug(
|
||||
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
|
||||
),
|
||||
)
|
||||
|
||||
if traverse_subfolders and folder_id is not None:
|
||||
folder_ids_traversed = folder_ids_traversed or []
|
||||
subfolders = _get_folders(
|
||||
service=service,
|
||||
folder_id=folder_id,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
for subfolder in subfolders:
|
||||
if subfolder["id"] not in folder_ids_traversed:
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
folder_ids_traversed.append(subfolder["id"])
|
||||
yield from get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=continue_on_failure,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
batch_size=batch_size,
|
||||
time_range_start=time_range_start,
|
||||
time_range_end=time_range_end,
|
||||
folder_id=subfolder["id"],
|
||||
traverse_subfolders=traverse_subfolders,
|
||||
folder_ids_traversed=folder_ids_traversed,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Skipping subfolder since already traversed: " + subfolder["name"]
|
||||
)
|
||||
|
||||
|
||||
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(
|
||||
file=io.BytesIO(response), file_name=file.get("name", file["id"])
|
||||
)
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
include_shared_drives: bool = True,
|
||||
shared_drive_urls: str | None = None,
|
||||
include_my_drives: bool = True,
|
||||
my_drive_emails: str | None = None,
|
||||
shared_folder_urls: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# OLD PARAMETERS
|
||||
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
|
||||
# if specified, will only index files in these folders
|
||||
folder_paths: list[str] | None = None,
|
||||
include_shared: bool | None = None,
|
||||
follow_shortcuts: bool | None = None,
|
||||
only_org_public: bool | None = None,
|
||||
continue_on_failure: bool | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
|
||||
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
# Check for old input parameters
|
||||
if (
|
||||
folder_paths is not None
|
||||
or include_shared is not None
|
||||
or follow_shortcuts is not None
|
||||
or only_org_public is not None
|
||||
or continue_on_failure is not None
|
||||
):
|
||||
logger.exception(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
|
||||
if (
|
||||
not include_shared_drives
|
||||
and not include_my_drives
|
||||
and not shared_folder_urls
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of include_shared_drives, include_my_drives,"
|
||||
" or shared_folder_urls must be true"
|
||||
)
|
||||
|
||||
self.folder_paths = folder_paths or []
|
||||
self.batch_size = batch_size
|
||||
self.include_shared = include_shared
|
||||
self.follow_shortcuts = follow_shortcuts
|
||||
self.only_org_public = only_org_public
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
self.include_shared_drives = include_shared_drives
|
||||
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
|
||||
self._requested_shared_drive_ids = set(
|
||||
_extract_ids_from_urls(shared_drive_url_list)
|
||||
)
|
||||
@staticmethod
|
||||
def _process_folder_paths(
|
||||
service: discovery.Resource,
|
||||
folder_paths: list[str],
|
||||
include_shared: bool,
|
||||
follow_shortcuts: bool,
|
||||
) -> list[str]:
|
||||
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
|
||||
folder_ids: list[str] = []
|
||||
for path in folder_paths:
|
||||
folder_names = path.split("/")
|
||||
parent_id = "root"
|
||||
for folder_name in folder_names:
|
||||
found_parent_id = _get_folder_id(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
folder_name=folder_name,
|
||||
include_shared=include_shared,
|
||||
follow_shortcuts=follow_shortcuts,
|
||||
)
|
||||
if found_parent_id is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"Folder '{folder_name}' in path '{path}' "
|
||||
"not found in Google Drive"
|
||||
)
|
||||
)
|
||||
parent_id = found_parent_id
|
||||
folder_ids.append(parent_id)
|
||||
|
||||
self.include_my_drives = include_my_drives
|
||||
self._requested_my_drive_emails = set(
|
||||
_extract_str_list_from_comma_str(my_drive_emails)
|
||||
)
|
||||
|
||||
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
|
||||
self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
|
||||
|
||||
self._primary_admin_email: str | None = None
|
||||
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
|
||||
self._retrieved_ids: set[str] = set()
|
||||
|
||||
@property
|
||||
def primary_admin_email(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email
|
||||
|
||||
@property
|
||||
def google_domain(self) -> str:
|
||||
if self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._primary_admin_email.split("@")[-1]
|
||||
|
||||
@property
|
||||
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
|
||||
if self._creds is None:
|
||||
raise RuntimeError(
|
||||
"Creds missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
return self._creds
|
||||
return folder_ids
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
self._primary_admin_email = primary_admin_email
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._retrieved_ids.add(folder_id)
|
||||
|
||||
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
query = "isAdmin=true" if admins_only else "isAdmin=false"
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
all_drive_ids = set()
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=primary_drive_service.drives().list,
|
||||
list_key="drives",
|
||||
useDomainAdminAccess=True,
|
||||
fields="drives(id)",
|
||||
):
|
||||
all_drive_ids.add(drive["id"])
|
||||
return all_drive_ids
|
||||
|
||||
def _initialize_all_class_variables(self) -> None:
|
||||
# Get all user emails
|
||||
# Get admins first becuase they are more likely to have access to the most files
|
||||
user_emails = [self.primary_admin_email]
|
||||
for admins_only in [True, False]:
|
||||
for email in self._get_all_user_emails(admins_only=admins_only):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
self._all_org_emails = user_emails
|
||||
|
||||
self._all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
# remove drive ids from the folder ids because they are queried differently
|
||||
self._requested_folder_ids -= self._all_drive_ids
|
||||
|
||||
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
|
||||
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
|
||||
if invalid_drive_ids:
|
||||
logger.warning(
|
||||
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
|
||||
)
|
||||
logger.warning("Checking for folder access instead...")
|
||||
self._requested_folder_ids.update(invalid_drive_ids)
|
||||
|
||||
if not self.include_shared_drives:
|
||||
self._requested_shared_drive_ids = set()
|
||||
elif not self._requested_shared_drive_ids:
|
||||
self._requested_shared_drive_ids = self._all_drive_ids
|
||||
|
||||
def _impersonate_user_for_retrieval(
|
||||
self,
|
||||
user_email: str,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
if self.include_my_drives and (
|
||||
not self._requested_my_drive_emails
|
||||
or user_email in self._requested_my_drive_emails
|
||||
):
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
is_slim=is_slim,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
traversed_parent_ids=self._retrieved_ids,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _fetch_drive_items(
|
||||
self,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
self._initialize_all_class_variables()
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
executor.submit(
|
||||
self._impersonate_user_for_retrieval, email, is_slim, start, end
|
||||
): email
|
||||
for email in self._all_org_emails
|
||||
}
|
||||
|
||||
# Yield results as they complete
|
||||
for future in as_completed(future_to_email):
|
||||
yield from future.result()
|
||||
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
if remaining_folders:
|
||||
logger.warning(
|
||||
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
|
||||
)
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
def _fetch_docs_from_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Create a larger process pool for file conversion
|
||||
convert_func = partial(
|
||||
_convert_single_file, self.creds, self.primary_admin_email
|
||||
if self.creds is None:
|
||||
raise PermissionError("Not logged into Google Drive")
|
||||
|
||||
service = discovery.build("drive", "v3", credentials=self.creds)
|
||||
folder_ids: Sequence[str | None] = self._process_folder_paths(
|
||||
service, self.folder_paths, self.include_shared, self.follow_shortcuts
|
||||
)
|
||||
if not folder_ids:
|
||||
folder_ids = [None]
|
||||
|
||||
# Process files in larger batches
|
||||
LARGE_BATCH_SIZE = self.batch_size * 4
|
||||
files_to_process = []
|
||||
# Gather the files into batches to be processed in parallel
|
||||
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
|
||||
files_to_process.append(file)
|
||||
if len(files_to_process) >= LARGE_BATCH_SIZE:
|
||||
yield from _process_files_batch(
|
||||
files_to_process, convert_func, self.batch_size
|
||||
file_batches = chain(
|
||||
*[
|
||||
get_all_files_batched(
|
||||
service=service,
|
||||
continue_on_failure=self.continue_on_failure,
|
||||
include_shared=self.include_shared,
|
||||
follow_shortcuts=self.follow_shortcuts,
|
||||
batch_size=self.batch_size,
|
||||
time_range_start=start,
|
||||
time_range_end=end,
|
||||
folder_id=folder_id,
|
||||
traverse_subfolders=True,
|
||||
)
|
||||
files_to_process = []
|
||||
for folder_id in folder_ids
|
||||
]
|
||||
)
|
||||
for files_batch in file_batches:
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
continue
|
||||
|
||||
# Process any remaining files
|
||||
if files_to_process:
|
||||
yield from _process_files_batch(
|
||||
files_to_process, convert_func, self.batch_size
|
||||
)
|
||||
if self.only_org_public:
|
||||
if "permissions" not in file:
|
||||
continue
|
||||
if not any(
|
||||
permission["type"] == "domain"
|
||||
for permission in file["permissions"]
|
||||
):
|
||||
continue
|
||||
try:
|
||||
text_contents = extract_text(file, service) or ""
|
||||
except HttpError as e:
|
||||
reason = (
|
||||
e.error_details[0]["reason"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
message = (
|
||||
e.error_details[0]["message"]
|
||||
if e.error_details
|
||||
else e.reason
|
||||
)
|
||||
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
continue
|
||||
|
||||
raise
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[
|
||||
Section(link=file["webViewLink"], text=text_contents)
|
||||
],
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Ran into exception when pulling a file from Google Drive"
|
||||
)
|
||||
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive()
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
yield from self._fetch_docs_from_drive()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
# need to subtract 10 minutes from start time to account for modifiedTime
|
||||
# propogation if a document is modified, it takes some time for the API to
|
||||
# reflect these changes if we do not have an offset, then we may "miss" the
|
||||
# update when polling
|
||||
yield from self._fetch_docs_from_drive(start, end)
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
for file in self._fetch_drive_items(
|
||||
is_slim=True,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if doc := build_slim_document(file):
|
||||
slim_batch.append(doc)
|
||||
if len(slim_batch) >= SLIM_BATCH_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
try:
|
||||
yield from self._extract_slim_docs_from_google_drive(start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
import os
|
||||
|
||||
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
|
||||
if not service_account_json_path:
|
||||
raise ValueError(
|
||||
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
|
||||
)
|
||||
with open(service_account_json_path) as f:
|
||||
creds = json.load(f)
|
||||
|
||||
credentials_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
|
||||
}
|
||||
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
|
||||
if delegated_user:
|
||||
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
|
||||
|
||||
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print(document_batch)
|
||||
break
|
||||
|
||||
229
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
229
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import BASE_SCOPES
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_gdrive_scopes() -> list[str]:
|
||||
base_scopes: list[str] = BASE_SCOPES
|
||||
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
|
||||
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
return base_scopes + permissions_scopes + groups_scopes
|
||||
return base_scopes + permissions_scopes
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh google drive access token due to: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str, scopes=scopes
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def update_credential_access_tokens(
|
||||
auth_code: str,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
|
||||
credential_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if delegated_user_email:
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
@@ -1,4 +1,7 @@
|
||||
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
|
||||
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
|
||||
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
|
||||
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
from danswer.connectors.google_drive.models import GDriveMimeType
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_drive.section_extraction import get_document_sections
|
||||
from danswer.connectors.google_utils.resources import GoogleDocsService
|
||||
from danswer.connectors.google_utils.resources import GoogleDriveService
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
|
||||
|
||||
def _extract_sections_basic(
|
||||
file: dict[str, str], service: GoogleDriveService
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
text = (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=service.files()
|
||||
.get_media(fileId=file["id"])
|
||||
.execute()
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
if get_unstructured_api_key():
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=unstructured_to_text(
|
||||
file=io.BytesIO(response),
|
||||
file_name=file.get("name", file["id"]),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return [
|
||||
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return [Section(link=link, text=text)]
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return [
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
) -> Document | None:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
return None
|
||||
# Skip files that are folders
|
||||
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
|
||||
logger.info("Ignoring Drive Folder Filetype")
|
||||
return None
|
||||
|
||||
sections: list[Section] = []
|
||||
|
||||
# Special handling for Google Docs to preserve structure, link
|
||||
# to headers
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
sections = get_document_sections(docs_service, file["id"])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
|
||||
if not sections:
|
||||
try:
|
||||
# For all other file types just extract the text
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
|
||||
except HttpError as e:
|
||||
reason = e.error_details[0]["reason"] if e.error_details else e.reason
|
||||
message = e.error_details[0]["message"] if e.error_details else e.reason
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
return None
|
||||
|
||||
raise
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
return Document(
|
||||
id=file["webViewLink"],
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={}
|
||||
if any(section.text for section in sections)
|
||||
else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
except Exception as e:
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when pulling a file from Google Drive")
|
||||
return None
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
# Skip files that are folders or shortcuts
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
|
||||
return SlimDocument(
|
||||
id=file["webViewLink"],
|
||||
perm_sync_data={
|
||||
"doc_id": file.get("id"),
|
||||
"permissions": file.get("permissions", []),
|
||||
"permission_ids": file.get("permissionIds", []),
|
||||
"name": file.get("name"),
|
||||
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
|
||||
},
|
||||
)
|
||||
@@ -1,222 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
|
||||
"shortcutDetails, owners(emailAddress))"
|
||||
)
|
||||
SLIM_FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
|
||||
"permissionIds, webViewLink, owners(emailAddress))"
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
|
||||
def _generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime >= '{time_start}'"
|
||||
if end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime <= '{time_stop}'"
|
||||
return time_range_filter
|
||||
|
||||
|
||||
def _get_folders_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# Follow shortcuts to folders
|
||||
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
|
||||
query += " and trashed = false"
|
||||
|
||||
if parent_id:
|
||||
query += f" and '{parent_id}' in parents"
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=FOLDER_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
yield file
|
||||
|
||||
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
if parent_id in traversed_parent_ids:
|
||||
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
|
||||
return
|
||||
|
||||
found_files = False
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
found_files = True
|
||||
yield file
|
||||
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
logger.info("Fetching all files in subfolder: " + subfolder["name"])
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
is_slim: bool = False,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
query += " and trashed = false"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields="nextPageToken, files(id)",
|
||||
q=query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(drive_id)
|
||||
|
||||
# Get all files in the shared drive
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
query += " and trashed = false"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
driveId=drive_id,
|
||||
supportsAllDrives=True,
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_in_my_drive(
|
||||
service: Any,
|
||||
update_traversed_ids_func: Callable,
|
||||
is_slim: bool = False,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
query = "trashed = false and 'me' in owners"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
|
||||
# Then get the files
|
||||
query = "trashed = false and 'me' in owners"
|
||||
query += _generate_time_range_filter(start, end)
|
||||
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
|
||||
if not is_slim:
|
||||
fields += ", files(permissions, permissionIds, owners)"
|
||||
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
)
|
||||
|
||||
|
||||
# Just in case we need to get the root folder id
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields="id").execute()["id"]
|
||||
@@ -1,18 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
@@ -1,105 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.connectors.google_utils.resources import GoogleDocsService
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
|
||||
class CurrentHeading(BaseModel):
|
||||
id: str
|
||||
text: str
|
||||
|
||||
|
||||
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
|
||||
"""Builds a Google Doc link that jumps to a specific heading"""
|
||||
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
|
||||
# @Chris
|
||||
return (
|
||||
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the id from a heading paragraph element"""
|
||||
return paragraph["paragraphStyle"]["headingId"]
|
||||
|
||||
|
||||
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||
"""Extracts the text content from a paragraph element"""
|
||||
text_elements = []
|
||||
for element in paragraph.get("elements", []):
|
||||
if "textRun" in element:
|
||||
text_elements.append(element["textRun"].get("content", ""))
|
||||
return "".join(text_elements)
|
||||
|
||||
|
||||
def get_document_sections(
|
||||
docs_service: GoogleDocsService,
|
||||
doc_id: str,
|
||||
) -> list[Section]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
doc = docs_service.documents().get(documentId=doc_id).execute()
|
||||
|
||||
# Get the content
|
||||
content = doc.get("body", {}).get("content", [])
|
||||
|
||||
sections: list[Section] = []
|
||||
current_section: list[str] = []
|
||||
current_heading: CurrentHeading | None = None
|
||||
|
||||
for element in content:
|
||||
if "paragraph" not in element:
|
||||
continue
|
||||
|
||||
paragraph = element["paragraph"]
|
||||
|
||||
# Check if this is a heading
|
||||
if (
|
||||
"paragraphStyle" in paragraph
|
||||
and "namedStyleType" in paragraph["paragraphStyle"]
|
||||
):
|
||||
style = paragraph["paragraphStyle"]["namedStyleType"]
|
||||
is_heading = style.startswith("HEADING_")
|
||||
is_title = style.startswith("TITLE")
|
||||
|
||||
if is_heading or is_title:
|
||||
# If we were building a previous section, add it to sections list
|
||||
if current_heading is not None and current_section:
|
||||
heading_text = current_heading.text
|
||||
section_text = f"{heading_text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
current_section = []
|
||||
|
||||
# Start new heading
|
||||
heading_id = _extract_id_from_heading(paragraph)
|
||||
heading_text = _extract_text_from_paragraph(paragraph)
|
||||
current_heading = CurrentHeading(
|
||||
id=heading_id,
|
||||
text=heading_text,
|
||||
)
|
||||
continue
|
||||
|
||||
# Add content to current section
|
||||
if current_heading is not None:
|
||||
text = _extract_text_from_paragraph(paragraph)
|
||||
if text.strip():
|
||||
current_section.append(text)
|
||||
|
||||
# Don't forget to add the last section
|
||||
if current_heading is not None and current_section:
|
||||
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
)
|
||||
|
||||
return sections
|
||||
@@ -1,107 +0,0 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_google_oauth_creds(
|
||||
token_json_str: str, source: DocumentSource
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(
|
||||
info=creds_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
if creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception:
|
||||
logger.exception("Failed to refresh google drive access token due to:")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_google_creds(
|
||||
credentials: dict[str, str],
|
||||
source: DocumentSource,
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
# OAUTH
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=access_token_json_str, source=source
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
|
||||
],
|
||||
}
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
# SERVICE ACCOUNT
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
|
||||
service_creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=GOOGLE_SCOPES[source]
|
||||
)
|
||||
|
||||
if not service_creds.valid or not service_creds.expired:
|
||||
service_creds.refresh(Request())
|
||||
|
||||
if not service_creds.valid:
|
||||
raise PermissionError(
|
||||
f"Unable to access {source} - service account credentials are invalid."
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
f"Unable to access {source} - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
@@ -1,237 +0,0 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_utils.resources import get_drive_service
|
||||
from danswer.connectors.google_utils.resources import get_gmail_service
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
MISSING_SCOPES_ERROR_STR,
|
||||
)
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
ONYX_SCOPE_INSTRUCTIONS,
|
||||
)
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
elif source == DocumentSource.GMAIL:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
|
||||
def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
drive_service = get_drive_service(creds)
|
||||
user_info = (
|
||||
drive_service.about()
|
||||
.get(
|
||||
fields="user(emailAddress)",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
email = user_info.get("user", {}).get("emailAddress")
|
||||
elif source == DocumentSource.GMAIL:
|
||||
gmail_service = get_gmail_service(creds)
|
||||
user_info = (
|
||||
gmail_service.users()
|
||||
.getProfile(
|
||||
userId="me",
|
||||
fields="emailAddress",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
email = user_info.get("emailAddress")
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return email
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
)
|
||||
|
||||
|
||||
def update_credential_access_tokens(
|
||||
auth_code: str,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred(source)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
redirect_uri=_build_frontend_google_drive_redirect(source),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
creds = flow.credentials
|
||||
token_json_str = creds.to_json()
|
||||
|
||||
# Get user email from Google API so we know who
|
||||
# the primary admin is for this connector
|
||||
try:
|
||||
email = _get_current_oauth_user(creds, source)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
return None
|
||||
return creds
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
primary_admin_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key(source=source)
|
||||
|
||||
credential_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
}
|
||||
if primary_admin_email:
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
redirect_uri=_build_frontend_google_drive_redirect(source),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
app_credentials: GoogleAppCredentials, source: DocumentSource
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
|
||||
def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().delete(KV_GMAIL_CRED_KEY)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey, source: DocumentSource
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
|
||||
def delete_service_account_key(source: DocumentSource) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
@@ -1,125 +0,0 @@
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from danswer.connectors.google_drive.models import GoogleDriveFileType
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _execute_with_retry(request: Any) -> Any:
|
||||
max_attempts = 10
|
||||
attempt = 1
|
||||
|
||||
while attempt < max_attempts:
|
||||
# Note for reasons unknown, the Google API will sometimes return a 429
|
||||
# and even after waiting the retry period, it will return another 429.
|
||||
# It could be due to a few possibilities:
|
||||
# 1. Other things are also requesting from the Gmail API with the same key
|
||||
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
|
||||
# 3. The retry-after has a maximum and we've already hit the limit for the day
|
||||
# or it's something else...
|
||||
try:
|
||||
return request.execute()
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if error.resp.status == 429:
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
sleep_time = int(retry_after)
|
||||
else:
|
||||
# Extract 'Retry after' timestamp from error message
|
||||
match = re.search(
|
||||
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
|
||||
str(error),
|
||||
)
|
||||
if match:
|
||||
retry_after_timestamp = match.group(1)
|
||||
retry_after_dt = datetime.strptime(
|
||||
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
).replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
sleep_time = max(
|
||||
int((retry_after_dt - current_time).total_seconds()),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"No Retry-After header or timestamp found in error message: {error}"
|
||||
)
|
||||
sleep_time = 60
|
||||
|
||||
sleep_time += 3 # Add a buffer to be safe
|
||||
|
||||
logger.info(
|
||||
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
else:
|
||||
raise
|
||||
|
||||
# If we've exhausted all attempts
|
||||
raise Exception(f"Failed to execute request after {max_attempts} attempts")
|
||||
|
||||
|
||||
def execute_paginated_retrieval(
|
||||
retrieval_function: Callable,
|
||||
list_key: str | None = None,
|
||||
continue_on_404_or_403: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
"""Execute a paginated retrieval from Google Drive API
|
||||
Args:
|
||||
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||
**kwargs: Arguments to pass to the list function
|
||||
"""
|
||||
next_page_token = ""
|
||||
while next_page_token is not None:
|
||||
request_kwargs = kwargs.copy()
|
||||
if next_page_token:
|
||||
request_kwargs["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
except HttpError as e:
|
||||
if e.resp.status >= 500:
|
||||
results = add_retries(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)()
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.warning(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
else:
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
next_page_token = results.get("nextPageToken")
|
||||
if list_key:
|
||||
for item in results.get(list_key, []):
|
||||
yield item
|
||||
else:
|
||||
yield results
|
||||
@@ -1,63 +0,0 @@
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GoogleDocsService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class AdminService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class GmailService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
def _get_google_service(
|
||||
service_name: str,
|
||||
service_version: str,
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
|
||||
if isinstance(creds, ServiceAccountCredentials):
|
||||
creds = creds.with_subject(user_email)
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
elif isinstance(creds, OAuthCredentials):
|
||||
service = build(service_name, service_version, credentials=creds)
|
||||
|
||||
return service
|
||||
|
||||
|
||||
def get_google_docs_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDocsService:
|
||||
return _get_google_service("docs", "v1", creds, user_email)
|
||||
|
||||
|
||||
def get_drive_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
return _get_google_service("drive", "v3", creds, user_email)
|
||||
|
||||
|
||||
def get_admin_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> AdminService:
|
||||
return _get_google_service("admin", "directory_v1", creds, user_email)
|
||||
|
||||
|
||||
def get_gmail_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GmailService:
|
||||
return _get_google_service("gmail", "v1", creds, user_email)
|
||||
@@ -1,40 +0,0 @@
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
|
||||
# this is counted under `/auth/drive.readonly`
|
||||
GOOGLE_SCOPES = {
|
||||
DocumentSource.GOOGLE_DRIVE: [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
],
|
||||
DocumentSource.GMAIL: [
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly",
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly",
|
||||
],
|
||||
}
|
||||
|
||||
# This is the Oauth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
|
||||
USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
|
||||
# Error message substrings
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
|
||||
# Documentation and error messages
|
||||
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
|
||||
ONYX_SCOPE_INSTRUCTIONS = (
|
||||
"You have upgraded Danswer without updating the Google Auth scopes. "
|
||||
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
|
||||
)
|
||||
|
||||
|
||||
# This is the maximum number of threads that can be retrieved at once
|
||||
SLIM_BATCH_SIZE = 500
|
||||
@@ -56,11 +56,7 @@ class PollConnector(BaseConnector):
|
||||
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import functools
|
||||
import itertools
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
@@ -19,8 +18,6 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
@@ -26,8 +25,6 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
|
||||
def pywikibot_timestamp_to_utc_datetime(
|
||||
timestamp: pywikibot.time.Timestamp,
|
||||
@@ -124,6 +121,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
|
||||
# short names can only have ascii letters and digits
|
||||
|
||||
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
|
||||
self.site = pywikibot.Site(fam=self.family, code=language_code)
|
||||
self.categories = [
|
||||
|
||||
@@ -251,11 +251,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
@@ -391,11 +391,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects import Ticket # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
@@ -17,244 +20,43 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
|
||||
|
||||
class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"Zendesk Credentials are not set up, was load_credentials called?"
|
||||
)
|
||||
|
||||
|
||||
class ZendeskClient:
|
||||
def __init__(self, subdomain: str, email: str, token: str):
|
||||
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
|
||||
self.auth = (f"{email}/token", token)
|
||||
|
||||
@retry_builder()
|
||||
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
content_tags: dict[str, str] = {}
|
||||
params = {"page[size]": MAX_PAGE_SIZE}
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = client.make_request("guide/content_tags", params)
|
||||
|
||||
for tag in data.get("records", []):
|
||||
content_tags[tag["id"]] = tag["name"]
|
||||
|
||||
# Check if there are more pages
|
||||
if data.get("meta", {}).get("has_more", False):
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
else:
|
||||
break
|
||||
|
||||
return content_tags
|
||||
except Exception as e:
|
||||
raise Exception(f"Error fetching content tags: {str(e)}")
|
||||
|
||||
|
||||
def _get_articles(
|
||||
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = (
|
||||
{"start_time": start_time, "page[size]": page_size}
|
||||
if start_time
|
||||
else {"page[size]": page_size}
|
||||
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
|
||||
author = BasicExpertInfo(
|
||||
display_name=article.author.name, email=article.author.email
|
||||
)
|
||||
update_time = time_str_to_utc(article.updated_at)
|
||||
|
||||
while True:
|
||||
data = client.make_request("help_center/articles", params)
|
||||
for article in data["articles"]:
|
||||
yield article
|
||||
|
||||
if not data.get("meta", {}).get("has_more"):
|
||||
break
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
|
||||
|
||||
def _get_tickets(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"start_time": start_time} if start_time else {"start_time": 0}
|
||||
|
||||
while True:
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
for ticket in data["tickets"]:
|
||||
yield ticket
|
||||
|
||||
if not data.get("end_of_stream", False):
|
||||
params["start_time"] = data["end_time"]
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
author_data = client.make_request(f"users/{author_id}", {})
|
||||
user = author_data.get("user")
|
||||
return (
|
||||
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
|
||||
if user and user.get("name") and user.get("email")
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def _article_to_document(
|
||||
article: dict[str, Any],
|
||||
content_tags: dict[str, str],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
|
||||
author_id = article.get("author_id")
|
||||
if not author_id:
|
||||
author = None
|
||||
else:
|
||||
author = (
|
||||
author_map.get(author_id)
|
||||
if author_id in author_map
|
||||
else _fetch_author(client, author_id)
|
||||
)
|
||||
|
||||
new_author_mapping = {author_id: author} if author_id and author else None
|
||||
|
||||
updated_at = article.get("updated_at")
|
||||
update_time = time_str_to_utc(updated_at) if updated_at else None
|
||||
|
||||
# Build metadata
|
||||
# build metadata
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"labels": [str(label) for label in article.get("label_names", []) if label],
|
||||
"labels": [str(label) for label in article.label_names if label],
|
||||
"content_tags": [
|
||||
content_tags[tag_id]
|
||||
for tag_id in article.get("content_tag_ids", [])
|
||||
for tag_id in article.content_tag_ids
|
||||
if tag_id in content_tags
|
||||
],
|
||||
}
|
||||
|
||||
# Remove empty values
|
||||
# remove empty values
|
||||
metadata = {k: v for k, v in metadata.items() if v}
|
||||
|
||||
return new_author_mapping, Document(
|
||||
id=f"article:{article['id']}",
|
||||
return Document(
|
||||
id=f"article:{article.id}",
|
||||
sections=[
|
||||
Section(
|
||||
link=article.get("html_url"),
|
||||
text=parse_html_page_basic(article["body"]),
|
||||
)
|
||||
Section(link=article.html_url, text=parse_html_page_basic(article.body))
|
||||
],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=article["title"],
|
||||
semantic_identifier=article.title,
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[author] if author else None,
|
||||
primary_owners=[author],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def _get_comment_text(
|
||||
comment: dict[str, Any],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
|
||||
author_id = comment.get("author_id")
|
||||
if not author_id:
|
||||
author = None
|
||||
else:
|
||||
author = (
|
||||
author_map.get(author_id)
|
||||
if author_id in author_map
|
||||
else _fetch_author(client, author_id)
|
||||
)
|
||||
|
||||
new_author_mapping = {author_id: author} if author_id and author else None
|
||||
|
||||
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
|
||||
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
|
||||
|
||||
return new_author_mapping, comment_text
|
||||
|
||||
|
||||
def _ticket_to_document(
|
||||
ticket: dict[str, Any],
|
||||
author_map: dict[str, BasicExpertInfo],
|
||||
client: ZendeskClient,
|
||||
default_subdomain: str,
|
||||
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
|
||||
submitter_id = ticket.get("submitter")
|
||||
if not submitter_id:
|
||||
submitter = None
|
||||
else:
|
||||
submitter = (
|
||||
author_map.get(submitter_id)
|
||||
if submitter_id in author_map
|
||||
else _fetch_author(client, submitter_id)
|
||||
)
|
||||
|
||||
new_author_mapping = (
|
||||
{submitter_id: submitter} if submitter_id and submitter else None
|
||||
)
|
||||
|
||||
updated_at = ticket.get("updated_at")
|
||||
update_time = time_str_to_utc(updated_at) if updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if status := ticket.get("status"):
|
||||
metadata["status"] = status
|
||||
if priority := ticket.get("priority"):
|
||||
metadata["priority"] = priority
|
||||
if tags := ticket.get("tags"):
|
||||
metadata["tags"] = tags
|
||||
if ticket_type := ticket.get("type"):
|
||||
metadata["ticket_type"] = ticket_type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
|
||||
comments = comments_data.get("comments", [])
|
||||
|
||||
comment_texts = []
|
||||
for comment in comments:
|
||||
new_author_mapping, comment_text = _get_comment_text(
|
||||
comment, author_map, client
|
||||
)
|
||||
if new_author_mapping:
|
||||
author_map.update(new_author_mapping)
|
||||
comment_texts.append(comment_text)
|
||||
|
||||
comments_text = "\n\n".join(comment_texts)
|
||||
|
||||
subject = ticket.get("subject")
|
||||
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
|
||||
|
||||
ticket_url = ticket.get("url")
|
||||
subdomain = (
|
||||
ticket_url.split("//")[1].split(".zendesk.com")[0]
|
||||
if ticket_url
|
||||
else default_subdomain
|
||||
)
|
||||
|
||||
ticket_display_url = (
|
||||
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
|
||||
)
|
||||
|
||||
return new_author_mapping, Document(
|
||||
id=f"zendesk_ticket_{ticket['id']}",
|
||||
sections=[Section(link=ticket_display_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[submitter] if submitter else None,
|
||||
metadata=metadata,
|
||||
)
|
||||
class ZendeskClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Zendesk Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
@@ -264,10 +66,44 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
self.zendesk_client: Zenpy | None = None
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.content_type = content_type
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def _set_content_tags(
|
||||
self, subdomain: str, email: str, token: str, page_size: int = 30
|
||||
) -> None:
|
||||
# Construct the base URL
|
||||
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
|
||||
|
||||
# Set up authentication
|
||||
auth = (f"{email}/token", token)
|
||||
|
||||
# Set up pagination parameters
|
||||
params = {"page[size]": page_size}
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Make the GET request
|
||||
response = requests.get(base_url, auth=auth, params=params)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
content_tag_list = data.get("records", [])
|
||||
for tag in content_tag_list:
|
||||
self.content_tags[tag["id"]] = tag["name"]
|
||||
|
||||
# Check if there are more pages
|
||||
if data.get("meta", {}).get("has_more", False):
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code}\n{response.text}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error fetching content tags: {str(e)}")
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Subdomain is actually the whole URL
|
||||
@@ -276,23 +112,87 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
.replace("https://", "")
|
||||
.split(".zendesk.com")[0]
|
||||
)
|
||||
self.subdomain = subdomain
|
||||
|
||||
self.client = ZendeskClient(
|
||||
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
|
||||
self.zendesk_client = Zenpy(
|
||||
subdomain=subdomain,
|
||||
email=credentials["zendesk_email"],
|
||||
token=credentials["zendesk_token"],
|
||||
)
|
||||
self._set_content_tags(
|
||||
subdomain,
|
||||
credentials["zendesk_email"],
|
||||
credentials["zendesk_token"],
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def _ticket_to_document(self, ticket: Ticket) -> Document:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
owner = None
|
||||
if ticket.requester and ticket.requester.name and ticket.requester.email:
|
||||
owner = [
|
||||
BasicExpertInfo(
|
||||
display_name=ticket.requester.name, email=ticket.requester.email
|
||||
)
|
||||
]
|
||||
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if ticket.status is not None:
|
||||
metadata["status"] = ticket.status
|
||||
if ticket.priority is not None:
|
||||
metadata["priority"] = ticket.priority
|
||||
if ticket.tags:
|
||||
metadata["tags"] = ticket.tags
|
||||
if ticket.type is not None:
|
||||
metadata["ticket_type"] = ticket.type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments = self.zendesk_client.tickets.comments(ticket=ticket)
|
||||
|
||||
# Combine all comments into a single text
|
||||
comments_text = "\n\n".join(
|
||||
[
|
||||
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
|
||||
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
|
||||
for comment in comments
|
||||
if comment.body
|
||||
]
|
||||
)
|
||||
|
||||
# Combine ticket description and comments
|
||||
description = (
|
||||
ticket.description
|
||||
if hasattr(ticket, "description") and ticket.description
|
||||
else ""
|
||||
)
|
||||
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
|
||||
|
||||
# Extract subdomain from ticket.url
|
||||
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
|
||||
|
||||
# Build the html url for the ticket
|
||||
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
|
||||
|
||||
return Document(
|
||||
id=f"zendesk_ticket_{ticket.id}",
|
||||
sections=[Section(link=ticket_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=owner,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
self.content_tags = _get_content_tag_mapping(self.client)
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
@@ -304,30 +204,26 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = _get_articles(self.client, start_time=int(start) if start else None)
|
||||
|
||||
# This one is built on the fly as there may be more many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
articles = (
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
|
||||
if start is None
|
||||
else self.zendesk_client.help_center.articles.incremental( # type: ignore
|
||||
start_time=int(start)
|
||||
)
|
||||
)
|
||||
doc_batch = []
|
||||
for article in articles:
|
||||
if (
|
||||
article.get("body") is None
|
||||
or article.get("draft")
|
||||
article.body is None
|
||||
or article.draft
|
||||
or any(
|
||||
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
for label in article.get("label_names", [])
|
||||
for label in article.label_names
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
new_author_map, documents = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
doc_batch.append(_article_to_document(article, self.content_tags))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
@@ -338,14 +234,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
ticket_generator = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
)
|
||||
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
@@ -354,20 +246,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.get("status") == "deleted":
|
||||
if ticket.status == "deleted":
|
||||
continue
|
||||
|
||||
new_author_map, documents = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
doc_batch.append(self._ticket_to_document(ticket))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
@@ -385,6 +267,7 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import SlackBotConfig
|
||||
@@ -50,16 +48,3 @@ def validate_channel_names(
|
||||
)
|
||||
|
||||
return cleaned_channel_names
|
||||
|
||||
|
||||
# Scaling configurations for multi-tenant Slack bot handling
|
||||
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
|
||||
TENANT_HEARTBEAT_INTERVAL = (
|
||||
15 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
)
|
||||
TENANT_HEARTBEAT_EXPIRATION = (
|
||||
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
)
|
||||
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
|
||||
|
||||
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))
|
||||
|
||||
@@ -1,34 +1,18 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from threading import Event
|
||||
from types import FrameType
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import start_http_server
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
|
||||
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
|
||||
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
|
||||
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
|
||||
from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
@@ -62,7 +46,6 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import TenantSocketModeClient
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
@@ -70,24 +53,17 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Prometheus metric for HPA
|
||||
active_tenants_gauge = Gauge(
|
||||
"active_tenants", "Number of active tenants handled by this pod"
|
||||
)
|
||||
|
||||
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
|
||||
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
|
||||
# the cause.
|
||||
@@ -101,232 +77,10 @@ _SLACK_GREETINGS_TO_IGNORE = {
|
||||
":wave:",
|
||||
}
|
||||
|
||||
# This is always (currently) the user id of Slack's official slackbot
|
||||
# this is always (currently) the user id of Slack's official slackbot
|
||||
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
|
||||
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
self._shutdown_event = Event()
|
||||
logger.info(f"Pod ID: {self.pod_id}")
|
||||
|
||||
# Set up signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, self.shutdown)
|
||||
signal.signal(signal.SIGINT, self.shutdown)
|
||||
logger.info("Signal handlers registered")
|
||||
|
||||
# Start the Prometheus metrics server
|
||||
logger.info("Starting Prometheus metrics server")
|
||||
start_http_server(8000)
|
||||
logger.info("Prometheus metrics server started")
|
||||
|
||||
# Start background threads
|
||||
logger.info("Starting background threads")
|
||||
self.acquire_thread = threading.Thread(
|
||||
target=self.acquire_tenants_loop, daemon=True
|
||||
)
|
||||
self.heartbeat_thread = threading.Thread(
|
||||
target=self.heartbeat_loop, daemon=True
|
||||
)
|
||||
|
||||
self.acquire_thread.start()
|
||||
self.heartbeat_thread.start()
|
||||
logger.info("Background threads started")
|
||||
|
||||
def get_pod_id(self) -> str:
|
||||
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
|
||||
logger.info(f"Retrieved pod ID: {pod_id}")
|
||||
return pod_id
|
||||
|
||||
def acquire_tenants_loop(self) -> None:
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
self.acquire_tenants()
|
||||
active_tenants_gauge.set(len(self.tenant_ids))
|
||||
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in Slack acquisition: {e}")
|
||||
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
|
||||
|
||||
def heartbeat_loop(self) -> None:
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
self.send_heartbeats()
|
||||
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in heartbeat loop: {e}")
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def acquire_tenants(self) -> None:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
|
||||
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
):
|
||||
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
|
||||
continue
|
||||
|
||||
if tenant_id in self.tenant_ids:
|
||||
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
|
||||
continue
|
||||
|
||||
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
|
||||
logger.info(
|
||||
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
|
||||
)
|
||||
break
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
pod_id = self.pod_id
|
||||
acquired = redis_client.set(
|
||||
DanswerRedisLocks.SLACK_BOT_LOCK,
|
||||
pod_id,
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired lock for tenant {tenant_id}")
|
||||
self.tenant_ids.add(tenant_id)
|
||||
|
||||
for tenant_id in self.tenant_ids:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Setting tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
slack_bot_tokens = fetch_tokens()
|
||||
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
|
||||
logger.debug(
|
||||
f"Reset tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
if not slack_bot_tokens:
|
||||
logger.debug(
|
||||
f"No Slack bot token found for tenant {tenant_id}"
|
||||
)
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
continue
|
||||
|
||||
if (
|
||||
tenant_id not in self.slack_bot_tokens
|
||||
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
|
||||
):
|
||||
if tenant_id in self.slack_bot_tokens:
|
||||
logger.info(
|
||||
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
|
||||
)
|
||||
else:
|
||||
search_settings = get_current_search_settings(
|
||||
db_session
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if self.socket_clients.get(tenant_id):
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if self.socket_clients.get(tenant_id):
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def send_heartbeats(self) -> None:
|
||||
current_time = int(time.time())
|
||||
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
heartbeat_key = (
|
||||
f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
|
||||
)
|
||||
redis_client.set(
|
||||
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
logger.info(f"Starting socket client for tenant {tenant_id}")
|
||||
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
|
||||
|
||||
# Append the event handler
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info(f"Connecting socket client for tenant {tenant_id}")
|
||||
socket_client.connect()
|
||||
self.socket_clients[tenant_id] = socket_client
|
||||
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def stop_socket_clients(self) -> None:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
for tenant_id, client in self.socket_clients.items():
|
||||
if client:
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
logger.info("Shutting down gracefully")
|
||||
self.running = False
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Stop all socket clients
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
self.stop_socket_clients()
|
||||
|
||||
# Release locks for all tenants
|
||||
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
try:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
|
||||
logger.info(f"Released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
|
||||
|
||||
# Wait for background threads to finish (with timeout)
|
||||
logger.info("Waiting for background threads to finish...")
|
||||
self.acquire_thread.join(timeout=5)
|
||||
self.heartbeat_thread.join(timeout=5)
|
||||
|
||||
logger.info("Shutdown complete")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
if req.type == "events_api":
|
||||
@@ -418,7 +172,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
message_subtype = event.get("subtype")
|
||||
if message_subtype not in [None, "file_share"]:
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
|
||||
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -493,7 +247,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
|
||||
|
||||
def build_request_details(
|
||||
@@ -515,14 +269,14 @@ def build_request_details(
|
||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
|
||||
if DANSWER_BOT_REPHRASE_MESSAGE:
|
||||
logger.info(f"Rephrasing Slack message. Original message: {msg}")
|
||||
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
|
||||
try:
|
||||
msg = rephrase_slack_message(msg)
|
||||
logger.info(f"Rephrased message: {msg}")
|
||||
logger.notice(f"Rephrased message: {msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while trying to rephrase the Slack message: {e}")
|
||||
else:
|
||||
logger.info(f"Received Slack message: {msg}")
|
||||
logger.notice(f"Received Slack message: {msg}")
|
||||
|
||||
if tagged:
|
||||
logger.debug("User tagged DanswerBot")
|
||||
@@ -723,21 +477,94 @@ def _get_socket_client(
|
||||
)
|
||||
|
||||
|
||||
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
|
||||
socket_client.connect()
|
||||
|
||||
|
||||
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
|
||||
# the slack bot in your workspace, and then add the bot to any channels you want to
|
||||
# try and answer questions for. Running this file will setup Danswer to listen to all
|
||||
# messages in those channels and attempt to answer them. As of now, it will only respond
|
||||
# to messages sent directly in the channel - it will not respond to messages sent within a
|
||||
# thread.
|
||||
#
|
||||
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
|
||||
# without issue.
|
||||
if __name__ == "__main__":
|
||||
# Initialize the tenant handler which will manage tenant connections
|
||||
logger.info("Starting SlackbotHandler")
|
||||
tenant_handler = SlackbotHandler()
|
||||
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
|
||||
socket_clients: dict[str | None, TenantSocketModeClient] = {}
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
try:
|
||||
# Keep the main thread alive
|
||||
while tenant_handler.running:
|
||||
time.sleep(1)
|
||||
while True:
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
|
||||
|
||||
except Exception:
|
||||
logger.exception("Fatal error in main thread")
|
||||
tenant_handler.shutdown(None, None)
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
if (
|
||||
tenant_id not in slack_bot_tokens
|
||||
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
|
||||
):
|
||||
if tenant_id in slack_bot_tokens:
|
||||
logger.notice(
|
||||
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
|
||||
)
|
||||
else:
|
||||
# Initial setup for this tenant
|
||||
search_settings = get_current_search_settings(
|
||||
db_session
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
|
||||
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
|
||||
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
# to avoid + (1) if the user is changing tokens, they are likely okay with some
|
||||
# "migration downtime" and (2) if a single message is lost it is okay
|
||||
# as this should be a very rare occurrence
|
||||
if tenant_id in socket_clients:
|
||||
socket_clients[tenant_id].close()
|
||||
|
||||
socket_client = _get_socket_client(
|
||||
latest_slack_bot_tokens, tenant_id
|
||||
)
|
||||
|
||||
# Initialize socket client for this tenant. Each tenant has its own
|
||||
# socket client, allowing for multiple concurrent connections (one
|
||||
# per tenant) with the tenant ID wrapped in the socket model client.
|
||||
# Each `connect` stores websocket connection in a separate thread.
|
||||
_initialize_socket_client(socket_client)
|
||||
|
||||
socket_clients[tenant_id] = socket_client
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if tenant_id in socket_clients:
|
||||
socket_clients[tenant_id].disconnect()
|
||||
del socket_clients[tenant_id]
|
||||
del slack_bot_tokens[tenant_id]
|
||||
|
||||
# Wait before checking for updates
|
||||
Event().wait(timeout=60)
|
||||
|
||||
except Exception:
|
||||
logger.exception("An error occurred outside of main event loop")
|
||||
time.sleep(60)
|
||||
|
||||
@@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.api_key import get_api_key_email_pattern
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.models import AccessToken
|
||||
@@ -36,16 +35,12 @@ def get_default_admin_user_emails() -> list[str]:
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
def get_total_users_count(db_session: Session) -> int:
|
||||
def get_total_users(db_session: Session) -> int:
|
||||
"""
|
||||
Returns the total number of users in the system.
|
||||
This is the sum of users and invited users.
|
||||
"""
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
|
||||
.count()
|
||||
)
|
||||
user_count = db_session.query(User).count()
|
||||
invited_users = len(get_invited_users())
|
||||
return user_count + invited_users
|
||||
|
||||
|
||||
@@ -388,7 +388,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -474,7 +474,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@@ -494,7 +494,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_call = tool_call
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@@ -513,7 +513,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_call=tool_call,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@@ -749,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -351,11 +351,7 @@ def add_credential_to_connector(
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if access_type == AccessType.SYNC:
|
||||
if not fetch_ee_implementation_or_noop(
|
||||
"danswer.external_permissions.sync_params",
|
||||
"check_if_valid_sync_source",
|
||||
noop_return_value=True,
|
||||
)(connector.source):
|
||||
if not check_if_valid_sync_source(connector.source):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connector of type {connector.source} does not support SYNC access type",
|
||||
@@ -442,10 +438,7 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
if association is not None:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.external_perm",
|
||||
"delete_user__ext_group_for_cc_pair__no_commit",
|
||||
)(
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,10 @@ from sqlalchemy.sql.expression import or_
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_utils.shared_constants import (
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -421,15 +424,25 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_service_account_credentials(
|
||||
user: User | None, db_session: Session, source: DocumentSource
|
||||
def delete_gmail_service_account_credentials(
|
||||
user: User | None, db_session: Session
|
||||
) -> None:
|
||||
credentials = fetch_credentials(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
if (
|
||||
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
and credential.source == source
|
||||
if credential.credential_json.get(
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
):
|
||||
db_session.delete(credential)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_google_drive_service_account_credentials(
|
||||
user: User | None, db_session: Session
|
||||
) -> None:
|
||||
credentials = fetch_credentials(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
|
||||
db_session.delete(credential)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -29,7 +29,6 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from danswer.configs.app_configs import POSTGRES_DB
|
||||
from danswer.configs.app_configs import POSTGRES_HOST
|
||||
from danswer.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
|
||||
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
||||
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
@@ -38,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -310,12 +309,8 @@ async def get_async_session_with_tenant(
|
||||
try:
|
||||
# Set the search_path to the tenant's schema
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting search_path: {str(e)}")
|
||||
# You can choose to re-raise the exception or handle it
|
||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
||||
raise
|
||||
@@ -323,77 +318,47 @@ async def get_async_session_with_tenant(
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session using the current tenant ID from the context variable.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
3. Reverts to the previous tenant ID after the session is closed.
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
yield session
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
|
||||
@@ -219,7 +219,7 @@ def mark_attempt_partially_succeeded(
|
||||
|
||||
|
||||
def mark_attempt_failed(
|
||||
index_attempt_id: int,
|
||||
index_attempt: IndexAttempt,
|
||||
db_session: Session,
|
||||
failure_reason: str = "Unknown",
|
||||
full_exception_trace: str | None = None,
|
||||
@@ -227,7 +227,7 @@ def mark_attempt_failed(
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(IndexAttempt.id == index_attempt_id)
|
||||
.where(IndexAttempt.id == index_attempt.id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
|
||||
@@ -135,9 +135,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
recent_assistants: Mapped[list[dict]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
@@ -737,10 +734,9 @@ class IndexAttempt(Base):
|
||||
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
# Nullable because in the past, we didn't allow swapping out embedding models live
|
||||
search_settings_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("search_settings.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
ForeignKey("search_settings.id"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -760,7 +756,7 @@ class IndexAttempt(Base):
|
||||
"ConnectorCredentialPair", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
search_settings: Mapped[SearchSettings | None] = relationship(
|
||||
search_settings: Mapped[SearchSettings] = relationship(
|
||||
"SearchSettings", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
@@ -921,15 +917,10 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id"), nullable=False
|
||||
)
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
# Update the relationship
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage",
|
||||
back_populates="tool_call",
|
||||
uselist=False,
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
)
|
||||
|
||||
|
||||
@@ -1060,13 +1051,12 @@ class ChatMessage(Base):
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
@@ -1324,6 +1314,7 @@ class StarterMessage(TypedDict):
|
||||
in Postgres"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
message: str
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
with get_session_with_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
@@ -14,6 +14,7 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -22,14 +23,7 @@ logger = setup_logger()
|
||||
def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
"""Get count of cc-pairs and count of successful index_attempts for the
|
||||
new model grouped by connector + credential, if it's the same, then assume
|
||||
new index is done building. If so, swap the indices and expire the old one.
|
||||
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
|
||||
old_search_settings = None
|
||||
|
||||
new index is done building. If so, swap the indices and expire the old one."""
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
@@ -49,9 +43,9 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
|
||||
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
now_old_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
search_settings=now_old_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -73,6 +67,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
old_search_settings = current_search_settings
|
||||
|
||||
return old_search_settings
|
||||
if MULTI_TENANT:
|
||||
return now_old_search_settings
|
||||
return None
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
|
||||
|
||||
def fetch_all_user_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.USER
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
return db_session.scalars(query).all()
|
||||
|
||||
|
||||
def fetch_all_global_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
token_rate_limits = db_session.scalars(query).all()
|
||||
return token_rate_limits
|
||||
|
||||
|
||||
def insert_user_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.USER,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def insert_global_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.GLOBAL,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def update_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_id: int,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
token_limit.enabled = token_rate_limit_settings.enabled
|
||||
token_limit.token_budget = token_rate_limit_settings.token_budget
|
||||
token_limit.period_hours = token_rate_limit_settings.period_hours
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def delete_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_id: int,
|
||||
) -> None:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
db_session.query(TokenRateLimit__UserGroup).filter(
|
||||
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
|
||||
).delete()
|
||||
|
||||
db_session.delete(token_limit)
|
||||
db_session.commit()
|
||||
@@ -24,13 +24,6 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
return tool
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified name does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
@@ -44,7 +37,7 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.model_dump() for header in custom_headers]
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
|
||||
@@ -147,7 +147,7 @@ class VespaIndex(DocumentIndex):
|
||||
return None
|
||||
|
||||
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
|
||||
logger.notice(f"Deploying Vespa application package to {deploy_url}")
|
||||
logger.info(f"Deploying Vespa application package to {deploy_url}")
|
||||
|
||||
vespa_schema_path = os.path.join(
|
||||
os.getcwd(), "danswer", "document_index", "vespa", "app_config"
|
||||
|
||||
@@ -13,7 +13,6 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
CSV = "csv"
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
|
||||
@@ -8,13 +8,12 @@ import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
@@ -53,11 +52,11 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -76,10 +75,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
|
||||
|
||||
def save_files_from_urls(urls: list[str]) -> list[str]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
|
||||
(save_file_from_url, (url, tenant_id)) for url in urls
|
||||
(save_file_from_url, (url,)) for url in urls
|
||||
]
|
||||
# Must pass in tenant_id here, since this is called by multithreading
|
||||
return run_functions_tuples_in_parallel(funcs)
|
||||
|
||||
@@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KeyValueStore
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -1,44 +1,72 @@
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
|
||||
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
build_citation_processor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
build_quotes_processor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -46,9 +74,29 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_answer_stream_processor(
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
answer_style_configs: AnswerStyleConfig,
|
||||
) -> StreamProcessor:
|
||||
if answer_style_configs.citation_config:
|
||||
return build_citation_processor(
|
||||
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
|
||||
)
|
||||
if answer_style_configs.quotes_config:
|
||||
return build_quotes_processor(
|
||||
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
|
||||
)
|
||||
|
||||
raise RuntimeError("Not implemented yet")
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -88,6 +136,8 @@ class Answer:
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.skip_explicit_tool_calling = skip_explicit_tool_calling
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
@@ -112,141 +162,335 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
if self.answer_style_config.citation_config:
|
||||
prompt_builder.update_system_prompt(
|
||||
build_citations_system_message(self.prompt_config)
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = next(
|
||||
(t for t in self.tools if t.name == self.force_use_tool.tool_name), None
|
||||
)
|
||||
if tool is None:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
logger.info(
|
||||
f"Forcefully using tool='{tool.name}'"
|
||||
+ (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args is not None
|
||||
else ""
|
||||
)
|
||||
)
|
||||
return [tool]
|
||||
|
||||
def _handle_specified_tool_call(
|
||||
self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# make a dummy tool handler
|
||||
tool_handler = ToolResponseHandler([tool])
|
||||
|
||||
dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
dummy_tool_call_chunk.tool_calls = [
|
||||
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
]
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
if (
|
||||
current_llm_call.force_use_tool.force_use
|
||||
and current_llm_call.force_use_tool.args is not None
|
||||
):
|
||||
tool_name, tool_args = (
|
||||
current_llm_call.force_use_tool.tool_name,
|
||||
current_llm_call.force_use_tool.args,
|
||||
)
|
||||
tool = next(
|
||||
(t for t in current_llm_call.tools if t.name == tool_name), None
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{tool_name}' not found")
|
||||
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
if not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
chosen_tool_and_args = (
|
||||
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
current_llm_call, self.llm
|
||||
prompt_builder.update_user_prompt(
|
||||
build_citations_user_message(
|
||||
question=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
context_docs=final_context_documents,
|
||||
files=self.latest_query_files,
|
||||
all_doc_useful=(
|
||||
self.answer_style_config.citation_config.all_docs_useful
|
||||
if self.answer_style_config.citation_config
|
||||
else False
|
||||
),
|
||||
history_message=self.single_message_history or "",
|
||||
)
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_quotes_user_message(
|
||||
question=self.question,
|
||||
context_docs=final_context_documents,
|
||||
history_str=self.single_message_history or "",
|
||||
prompt=self.prompt_config,
|
||||
)
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
yield from self._handle_specified_tool_call(llm_calls, tool, tool_args)
|
||||
return
|
||||
|
||||
# if we're skipping gen ai answer generation, we should break
|
||||
# out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# answer, which is a no-no!
|
||||
if (
|
||||
self.skip_gen_ai_answer_generation
|
||||
and not current_llm_call.force_use_tool.force_use
|
||||
):
|
||||
def _raw_output_for_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
# as of now, we don't support multiple tool calls in sequence, which is why
|
||||
# we don't need to pass this in here
|
||||
# tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# set up "handlers" to listen to the LLM response stream and
|
||||
# feed back the processed results + handle tool call requests
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
self,
|
||||
prompt: Any,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[str | StreamStopInfo]:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk):
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield cast(str, message.content)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
|
||||
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
def _raw_output_for_non_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
if self.skip_gen_ai_answer_generation:
|
||||
raise ValueError(
|
||||
"skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated"
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
yield response
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
yield final
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
"required"
|
||||
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
else None
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -254,30 +498,94 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
single_message_history=self.single_message_history,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
output_generator = (
|
||||
self._raw_output_for_explicit_tool_calling_llms()
|
||||
if explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not self.skip_explicit_tool_calling
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
message = None
|
||||
|
||||
# special things we need to keep track of for the SearchTool
|
||||
# raw results that will be displayed to the user
|
||||
search_results: list[LlmDoc] | None = None
|
||||
# processed docs to feed into the LLM
|
||||
final_context_docs: list[LlmDoc] | None = None
|
||||
|
||||
for message in stream:
|
||||
if isinstance(message, ToolCallKickoff) or isinstance(
|
||||
message, ToolCallFinalResult
|
||||
):
|
||||
yield message
|
||||
elif isinstance(message, ToolResponse):
|
||||
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
# We don't need to run section merging in this flow, this variable is only used
|
||||
# below to specify the ordering of the documents for the purpose of matching
|
||||
# citations to the right search documents. The deduplication logic is more lightweight
|
||||
# there and we don't need to do it twice
|
||||
search_results = [
|
||||
llm_doc_from_inference_section(section)
|
||||
for section in cast(
|
||||
SearchResponseSummary, message.response
|
||||
).top_sections
|
||||
]
|
||||
elif message.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_docs = cast(list[LlmDoc], message.response)
|
||||
yield message
|
||||
|
||||
elif (
|
||||
message.id == SEARCH_DOC_CONTENT_ID
|
||||
and not self._return_contexts
|
||||
):
|
||||
continue
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
stream_stop_info = None
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
for item in itertools.chain([message], stream):
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
|
||||
# this should never happen, but we're seeing weird behavior here so handling for now
|
||||
if not isinstance(item, str):
|
||||
logger.error(
|
||||
f"Received non-string item in answer stream: {item}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
@@ -301,6 +609,7 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
self.answer_handler = answer_handler
|
||||
self.is_cancelled = is_cancelled
|
||||
|
||||
def handle_llm_response(
|
||||
self,
|
||||
stream: Iterator[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
all_messages: list[BaseMessage] = []
|
||||
for message in stream:
|
||||
if self.is_cancelled():
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
return
|
||||
# tool handler doesn't do anything until the full message is received
|
||||
# NOTE: still need to run list() to get this to run
|
||||
list(self.tool_handler.handle_response_part(message, all_messages))
|
||||
yield from self.answer_handler.handle_response_part(message, all_messages)
|
||||
all_messages.append(message)
|
||||
|
||||
# potentially give back all info on the selected tool call + its result
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -12,12 +12,12 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
@@ -54,14 +54,18 @@ def default_build_user_message(
|
||||
|
||||
class AnswerPromptBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
single_message_history: str | None = None,
|
||||
self, message_history: list[PreviousMessage], llm_config: LLMConfig
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
@@ -70,24 +74,6 @@ class AnswerPromptBuilder:
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
self.raw_message_history = message_history
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
# for cases where like the QA flow where we want to condense the chat history
|
||||
# into a single message rather than a sequence of User / Assistant messages
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
self.system_message_and_token_cnt = None
|
||||
@@ -99,21 +85,18 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
def update_user_prompt(self, user_message: HumanMessage) -> None:
|
||||
if not user_message:
|
||||
self.user_message_and_token_cnt = None
|
||||
return
|
||||
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
def append_message(self, message: BaseMessage) -> None:
|
||||
"""Append a new message to the message history."""
|
||||
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
|
||||
self.new_messages_and_token_cnts.append((message, token_count))
|
||||
|
||||
def get_user_message_content(self) -> str:
|
||||
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
|
||||
return query
|
||||
|
||||
def build(self) -> list[BaseMessage]:
|
||||
def build(
|
||||
self, tool_call_summary: ToolCallSummary | None = None
|
||||
) -> list[BaseMessage]:
|
||||
if not self.user_message_and_token_cnt:
|
||||
raise ValueError("User message must be set before building prompt")
|
||||
|
||||
@@ -130,8 +113,25 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if tool_call_summary:
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_request,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_request,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_result,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_result,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -6,6 +6,7 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
@@ -13,7 +14,6 @@ from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
@@ -132,9 +132,10 @@ def build_citations_system_message(
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
message: HumanMessage,
|
||||
question: str,
|
||||
prompt_config: PromptConfig,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
files: list[InMemoryChatFile],
|
||||
all_doc_useful: bool,
|
||||
history_message: str = "",
|
||||
) -> HumanMessage:
|
||||
@@ -148,7 +149,6 @@ def build_citations_user_message(
|
||||
if history_message
|
||||
else ""
|
||||
)
|
||||
query, img_urls = message_to_prompt_and_imgs(message)
|
||||
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
@@ -158,22 +158,20 @@ def build_citations_user_message(
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query,
|
||||
user_query=question,
|
||||
history_block=history_block,
|
||||
)
|
||||
else:
|
||||
# if no context docs provided, assume we're in the tool calling flow
|
||||
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query,
|
||||
user_query=question,
|
||||
history_block=history_block,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, img_urls=img_urls)
|
||||
if img_urls
|
||||
else user_prompt
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
||||
@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
@@ -76,7 +75,7 @@ def _build_strong_llm_quotes_prompt(
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
message: HumanMessage,
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
@@ -87,10 +86,28 @@ def build_quotes_user_message(
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return prompt_builder(
|
||||
question=query,
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
return prompt_builder(
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
|
||||
@@ -19,7 +19,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
|
||||
from danswer.tools.search.search_utils import section_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
|
||||
|
||||
class AnswerResponseHandler(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
# This is a dummy handler that returns nothing
|
||||
yield from []
|
||||
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# Process the new content through the citation processor
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
@@ -1,10 +1,12 @@
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -17,104 +19,128 @@ def in_code_block(llm_text: str) -> bool:
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
self.current_citations: list[int] = []
|
||||
self.past_cite_count = 0
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
"""
|
||||
Key aspects:
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
yield DanswerAnswerPiece(answer_piece=self.curr_segment)
|
||||
return
|
||||
1. Stream Processing:
|
||||
- Processes tokens one by one, allowing for real-time handling of large texts.
|
||||
|
||||
if self.stop_stream:
|
||||
next_hold = self.hold + token
|
||||
if self.stop_stream in next_hold:
|
||||
return
|
||||
if next_hold == self.stop_stream[: len(next_hold)]:
|
||||
self.hold = next_hold
|
||||
return
|
||||
2. Citation Detection:
|
||||
- Uses regex to find citations in the format [number].
|
||||
- Example: [1], [2], etc.
|
||||
|
||||
3. Citation Mapping:
|
||||
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
|
||||
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
|
||||
|
||||
4. Citation Formatting:
|
||||
- Replaces citations with properly formatted versions.
|
||||
- Adds links if available: [[1]](https://example.com)
|
||||
- Handles cases where links are not available: [[1]]()
|
||||
|
||||
5. Duplicate Handling:
|
||||
- Skips consecutive citations of the same document to avoid redundancy.
|
||||
|
||||
6. Output Generation:
|
||||
- Yields DanswerAnswerPiece objects for regular text.
|
||||
- Yields CitationInfo objects for each unique citation encountered.
|
||||
|
||||
7. Context Awareness:
|
||||
- Uses context_docs to access document information for citations.
|
||||
|
||||
This function effectively processes a stream of text, identifies and reformats citations,
|
||||
and provides both the processed text and citation information as output.
|
||||
"""
|
||||
order_mapping = doc_id_to_rank_map.order_mapping
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
citation_order = []
|
||||
curr_segment = ""
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
|
||||
raw_out = ""
|
||||
current_citations: list[int] = []
|
||||
past_cite_count = 0
|
||||
for raw_token in tokens:
|
||||
raw_out += raw_token
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
token = next_hold
|
||||
self.hold = ""
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
|
||||
self.curr_segment += token
|
||||
self.llm_out += token
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
if "`" in curr_segment:
|
||||
if curr_segment.endswith("`"):
|
||||
continue
|
||||
elif "```" in curr_segment:
|
||||
piece_that_comes_after = curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(llm_out):
|
||||
curr_segment = curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
|
||||
citations_found = list(re.finditer(citation_pattern, curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
# `past_cite_count`: number of characters since past citation
|
||||
# 5 to ensure a citation hasn't occured
|
||||
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
|
||||
current_citations = []
|
||||
|
||||
result = "" # Initialize result here
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
if citations_found and not in_code_block(llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[numerical_value - 1]
|
||||
real_citation_num = order_mapping[context_llm_doc.document_id]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
if real_citation_num not in citation_order:
|
||||
citation_order.append(real_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
target_citation_num = citation_order.index(real_citation_num) + 1
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
if target_citation_num in current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
curr_segment = (
|
||||
curr_segment[: length_to_add + start]
|
||||
+ curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
# by allowing it to generate citations on its own.
|
||||
if curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
context_llm_doc = context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
@@ -124,57 +150,75 @@ class CitationProcessor:
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
# Will continue attempt on next loops
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
past_cite_count = len(llm_out)
|
||||
current_citations.append(target_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
self.curr_segment = self.curr_segment[last_citation_end:]
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
|
||||
curr_segment = curr_segment[last_citation_end:]
|
||||
if possible_citation_found:
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
|
||||
if not possible_citation_found:
|
||||
result += self.curr_segment
|
||||
self.curr_segment = ""
|
||||
if curr_segment:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
|
||||
if result:
|
||||
yield DanswerAnswerPiece(answer_piece=result)
|
||||
|
||||
def build_citation_processor(
|
||||
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
) -> StreamProcessor:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from extract_citations_from_stream(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
return stream_processor
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
@@ -154,7 +157,7 @@ def separate_answer_quotes(
|
||||
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
|
||||
|
||||
|
||||
def _process_answer(
|
||||
def process_answer(
|
||||
answer_raw: str,
|
||||
docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
@@ -192,7 +195,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
def _extract_quotes_from_completed_token_stream(
|
||||
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
|
||||
) -> DanswerQuotes:
|
||||
answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
|
||||
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
|
||||
if answer:
|
||||
logger.notice(answer)
|
||||
elif model_output:
|
||||
@@ -201,101 +204,94 @@ def _extract_quotes_from_completed_token_stream(
|
||||
return quotes
|
||||
|
||||
|
||||
class QuotesProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.is_json_prompt = is_json_prompt
|
||||
def process_model_tokens(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
"""Used in the streaming case to process the model output
|
||||
into an Answer and Quotes
|
||||
|
||||
self.found_answer_start = False if is_json_prompt else True
|
||||
self.found_answer_end = False
|
||||
self.hold_quote = ""
|
||||
self.model_output = ""
|
||||
self.hold = ""
|
||||
Yields Answer tokens back out in a dict for streaming to frontend
|
||||
When Answer section ends, yields dict with answer_finished key
|
||||
Collects all the tokens at the end to form the complete model output"""
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
# Sometimes worse model outputs new line instead of :
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
# Sometime model outputs two newlines before quote section
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
model_output: str = ""
|
||||
found_answer_start = False if is_json_prompt else True
|
||||
found_answer_end = False
|
||||
hold_quote = ""
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
# None -> end of stream
|
||||
if token is None:
|
||||
if self.model_output:
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=self.model_output,
|
||||
context_docs=self.context_docs,
|
||||
is_json_prompt=self.is_json_prompt,
|
||||
)
|
||||
return
|
||||
for token in tokens:
|
||||
model_previous = model_output
|
||||
model_output += token
|
||||
|
||||
model_previous = self.model_output
|
||||
self.model_output += token
|
||||
|
||||
if not self.found_answer_start:
|
||||
m = answer_pattern.search(self.model_output)
|
||||
if not found_answer_start:
|
||||
m = answer_pattern.search(model_output)
|
||||
if m:
|
||||
self.found_answer_start = True
|
||||
found_answer_start = True
|
||||
|
||||
# Prevent heavy cases of hallucinations
|
||||
if self.is_json_prompt and len(self.model_output) > 70:
|
||||
# Prevent heavy cases of hallucinations where model is never providing a JSON
|
||||
# We want to quickly update the user - not stream forever
|
||||
if is_json_prompt and len(model_output) > 70:
|
||||
logger.warning("LLM did not produce json as prompted")
|
||||
self.found_answer_end = True
|
||||
return
|
||||
found_answer_end = True
|
||||
continue
|
||||
|
||||
remaining = self.model_output[m.end() :]
|
||||
|
||||
# Look for an unescaped quote, which means the answer is entirely contained
|
||||
# in this token e.g. if the token is `{"answer": "blah", "qu`
|
||||
quote_indices = [i for i, char in enumerate(remaining) if char == '"']
|
||||
for quote_idx in quote_indices:
|
||||
# Check if quote is escaped by counting backslashes before it
|
||||
num_backslashes = 0
|
||||
pos = quote_idx - 1
|
||||
while pos >= 0 and remaining[pos] == "\\":
|
||||
num_backslashes += 1
|
||||
pos -= 1
|
||||
# If even number of backslashes, quote is not escaped
|
||||
if num_backslashes % 2 == 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx])
|
||||
return
|
||||
|
||||
# If no unescaped quote found, yield the remaining string
|
||||
remaining = model_output[m.end() :]
|
||||
if len(remaining) > 0:
|
||||
yield DanswerAnswerPiece(answer_piece=remaining)
|
||||
return
|
||||
continue
|
||||
|
||||
if self.found_answer_start and not self.found_answer_end:
|
||||
if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
self.found_answer_end = True
|
||||
if found_answer_start and not found_answer_end:
|
||||
if is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
found_answer_end = True
|
||||
|
||||
# return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.'
|
||||
if token:
|
||||
try:
|
||||
answer_token_section = token.index('"')
|
||||
yield DanswerAnswerPiece(
|
||||
answer_piece=self.hold_quote + token[:answer_token_section]
|
||||
answer_piece=hold_quote + token[:answer_token_section]
|
||||
)
|
||||
except ValueError:
|
||||
logger.error("Quotation mark not found in token")
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
return
|
||||
|
||||
elif not self.is_json_prompt:
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
|
||||
if (
|
||||
quote_pat in self.hold_quote + token
|
||||
or quote_loose in self.hold_quote + token
|
||||
):
|
||||
self.found_answer_end = True
|
||||
continue
|
||||
elif not is_json_prompt:
|
||||
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
|
||||
found_answer_end = True
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
return
|
||||
if self.hold_quote + token in quote_pat_full:
|
||||
self.hold_quote += token
|
||||
return
|
||||
continue
|
||||
if hold_quote + token in quote_pat_full:
|
||||
hold_quote += token
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
hold_quote = ""
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=self.hold_quote + token)
|
||||
self.hold_quote = ""
|
||||
logger.debug(f"Raw Model QnA Output: {model_output}")
|
||||
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=model_output,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_processor(
|
||||
context_docs: list[LlmDoc], is_json_prompt: bool
|
||||
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
|
||||
def stream_processor(
|
||||
tokens: Iterator[str],
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
return stream_processor
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolResponseHandler:
|
||||
def __init__(self, tools: list[Tool]):
|
||||
self.tools = tools
|
||||
|
||||
self.tool_call_chunk: AIMessageChunk | None = None
|
||||
self.tool_call_requests: list[ToolCall] = []
|
||||
|
||||
self.tool_runner: ToolRunner | None = None
|
||||
self.tool_call_summary: ToolCallSummary | None = None
|
||||
|
||||
self.tool_kickoff: ToolCallKickoff | None = None
|
||||
self.tool_responses: list[ToolResponse] = []
|
||||
self.tool_final_result: ToolCallFinalResult | None = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_call_for_non_tool_calling_llm(
|
||||
cls, llm_call: LLMCall, llm: LLM
|
||||
) -> tuple[Tool, dict] | None:
|
||||
if llm_call.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
(
|
||||
t
|
||||
for t in llm_call.tools
|
||||
if t.name == llm_call.force_use_tool.tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(llm_call.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
return chosen_tool_and_args
|
||||
|
||||
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
|
||||
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
|
||||
return
|
||||
|
||||
self.tool_call_requests = self.tool_call_chunk.tool_calls
|
||||
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in self.tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
|
||||
if selected_tool and selected_tool_call_request:
|
||||
break
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
return
|
||||
|
||||
logger.info(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"])
|
||||
self.tool_kickoff = self.tool_runner.kickoff()
|
||||
yield self.tool_kickoff
|
||||
|
||||
for response in self.tool_runner.tool_responses():
|
||||
self.tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
self.tool_final_result = self.tool_runner.tool_final_result()
|
||||
yield self.tool_final_result
|
||||
|
||||
self.tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=self.tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
selected_tool_call_request, self.tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self._handle_tool_call()
|
||||
|
||||
if isinstance(response_item, AIMessageChunk) and (
|
||||
response_item.tool_call_chunks or response_item.tool_calls
|
||||
):
|
||||
if self.tool_call_chunk is None:
|
||||
self.tool_call_chunk = response_item
|
||||
else:
|
||||
self.tool_call_chunk += response_item # type: ignore
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
or self.tool_kickoff is None
|
||||
or self.tool_final_result is None
|
||||
):
|
||||
return None
|
||||
|
||||
tool_runner = self.tool_runner
|
||||
new_prompt_builder = tool_runner.tool.build_next_prompt(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tool_call_summary=self.tool_call_summary,
|
||||
tool_responses=self.tool_responses,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
args=None,
|
||||
),
|
||||
files=current_llm_call.files,
|
||||
using_tool_calling_llm=current_llm_call.using_tool_calling_llm,
|
||||
tool_call_info=[
|
||||
self.tool_kickoff,
|
||||
*self.tool_responses,
|
||||
self.tool_final_result,
|
||||
],
|
||||
)
|
||||
@@ -83,10 +83,8 @@ def _convert_litellm_message_to_langchain_message(
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
for tool_call in (tool_calls if tool_calls else [])
|
||||
],
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user