Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
81e1ac9183 remove empty directory 2024-10-28 16:10:34 -07:00
562 changed files with 20212 additions and 26971 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,17 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -95,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
severity: 'CRITICAL,HIGH'

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,76 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@v2
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# sarif_file: trivy-results.sarif

View File

@@ -197,8 +197,7 @@ jobs:
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
/app/tests/integration/tests
continue-on-error: true
id: run_tests
@@ -211,18 +210,17 @@ jobs:
echo "All integration tests passed successfully."
fi
# save before stopping the containers so the logs can be captured
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,72 +0,0 @@
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -0,0 +1,68 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
jobs:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -18,11 +18,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
jobs:
connectors-check:

View File

@@ -15,7 +15,7 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
model-check:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]

View File

@@ -203,7 +203,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion",
],
"presentation": {
"group": "2",
@@ -232,7 +232,7 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
"connector_pruning",
],
"presentation": {
"group": "2",

View File

@@ -1,5 +1,4 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
@@ -128,19 +127,3 @@ To try the Danswer Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

View File

@@ -12,6 +12,7 @@ ARG DANSWER_VERSION=0.8-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
@@ -38,6 +39,15 @@ RUN apt-get update && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -77,6 +87,7 @@ RUN apt-get update && \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \

View File

@@ -1,68 +0,0 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_new =
CASE
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
)
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"chosen_assistants_old",
postgresql.JSONB(),
nullable=False,
server_default="[-2, -1, 0]",
),
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_old =
CASE
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
)

View File

@@ -1,30 +0,0 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_external_group_sync",
sa.DateTime(timezone=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_time_external_group_sync")

View File

@@ -1,50 +0,0 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@@ -1,70 +0,0 @@
"""nullable search settings for historic index attempts
Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be nullable
op.alter_column(
"index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
)
# Add back the foreign key with ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
# Warning: This will delete all index attempts that don't have search settings
op.execute(
"""
DELETE FROM index_attempt
WHERE search_settings_id IS NULL
"""
)
# Drop foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Modify the column to be not nullable
op.alter_column(
"index_attempt",
"search_settings_id",
existing_type=sa.INTEGER(),
nullable=False,
)
# Add back the foreign key without ON DELETE SET NULL
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)

View File

@@ -1,30 +0,0 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: c0fd6e4da83a
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"creator_id",
sa.UUID(as_uuid=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "creator_id")

View File

@@ -288,15 +288,6 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -1,48 +0,0 @@
"""remove description from starter messages
Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem - 'description')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
UPDATE persona
SET starter_messages = (
SELECT jsonb_agg(elem || '{"description": ""}')
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -1,29 +0,0 @@
"""add recent assistants
Revision ID: c0fd6e4da83a
Revises: b72ed7a5db0e
Create Date: 2024-11-03 17:28:54.916618
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c0fd6e4da83a"
down_revision = "b72ed7a5db0e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)
def downgrade() -> None:
op.drop_column("user", "recent_assistants")

View File

@@ -23,56 +23,6 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -16,41 +16,6 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
external_access: ExternalAccess
# The document ID
doc_id: str
def to_dict(self) -> dict:
return {
"external_access": {
"external_user_emails": list(self.external_access.external_user_emails),
"external_user_group_ids": list(
self.external_access.external_user_group_ids
),
"is_public": self.external_access.is_public,
},
"doc_id": self.doc_id,
}
@classmethod
def from_dict(cls, data: dict) -> "DocExternalAccess":
external_access = ExternalAccess(
external_user_emails=set(
data["external_access"].get("external_user_emails", [])
),
external_user_group_ids=set(
data["external_access"].get("external_user_group_ids", [])
),
is_public=data["external_access"]["is_public"],
)
return cls(
external_access=external_access,
doc_id=data["doc_id"],
)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin

View File

@@ -1,89 +0,0 @@
import secrets
import uuid
from urllib.parse import quote
from urllib.parse import unquote
from fastapi import Request
from passlib.hash import sha256_crypt
from pydantic import BaseModel
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
_API_KEY_HEADER_NAME = "Authorization"
# NOTE for others who are curious: In the context of a header, "X-" often refers
# to non-standard, experimental, or custom headers in HTTP or other protocols. It
# indicates that the header is not part of the official standards defined by
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
class ApiKeyDescriptor(BaseModel):
api_key_id: int
api_key_display: str
api_key: str | None = None # only present on initial creation
api_key_name: str | None = None
api_key_role: UserRole
user_id: uuid.UUID
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
return f"{_API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(_API_KEY_LEN)}"
def extract_tenant_from_api_key_header(request: Request) -> str | None:
"""Extract tenant ID from request. Returns None if auth is disabled or invalid format."""
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if not raw_api_key_header or not raw_api_key_header.startswith(_BEARER_PREFIX):
return None
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
if len(parts) != 2:
return None
tenant_id = parts[0]
return unquote(tenant_id) if tenant_id else None
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def build_displayable_api_key(api_key: str) -> str:
if api_key.startswith(_API_KEY_PREFIX):
api_key = api_key[len(_API_KEY_PREFIX) :]
return _API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]
def get_hashed_api_key_from_request(request: Request) -> str | None:
raw_api_key_header = request.headers.get(
_API_KEY_HEADER_ALTERNATIVE_NAME
) or request.headers.get(_API_KEY_HEADER_NAME)
if raw_api_key_header is None:
return None
if raw_api_key_header.startswith(_BEARER_PREFIX):
raw_api_key_header = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
return hash_api_key(raw_api_key_header)

View File

@@ -15,7 +15,6 @@ class UserRole(str, Enum):
for all groups they are a member of
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"

View File

@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -75,7 +75,6 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
@@ -84,27 +83,24 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -194,6 +190,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -228,18 +238,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -260,7 +271,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@@ -281,9 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@@ -299,23 +308,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -366,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
@@ -448,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -478,7 +477,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -510,30 +510,19 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
async def write_token(self, user: User) -> str:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data
async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)
def get_jwt_strategy() -> TenantAwareJWTStrategy:
def get_jwt_strategy() -> JWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
@@ -635,12 +624,14 @@ async def double_check_user(
return None
if user is None:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
@@ -649,7 +640,8 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
@@ -662,24 +654,10 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
return await double_check_user(user)
async def current_curator_or_admin_user(
@@ -689,13 +667,15 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
@@ -707,7 +687,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
)
@@ -735,6 +716,8 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -784,22 +767,15 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request,
scopes: List[str] = Query(None),
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -858,11 +834,8 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -904,25 +877,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,7 +3,6 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -12,24 +11,18 @@ from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
@@ -115,43 +108,29 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(tenant_id, int(document_set_id))
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(tenant_id, int(usergroup_id))
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDelete.PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@@ -161,154 +140,27 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness probe starting.")
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@@ -316,7 +168,57 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("Redis: Readiness check succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("All tenant primary workers are ready. Continuing...")
return
@@ -328,20 +230,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
def on_setup_logging(

View File

@@ -3,162 +3,26 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = self.schedule.items()
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
@@ -169,4 +33,68 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.conf.beat_scheduler = DynamicTenantScheduler
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -59,15 +58,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -91,7 +84,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -6,7 +6,6 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -14,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,15 +58,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -82,11 +74,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@@ -91,7 +85,5 @@ celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

View File

@@ -13,23 +13,21 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -77,68 +75,108 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
RedisDocumentSet.reset_all(r)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
RedisUserGroup.reset_all(r)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorDelete.reset_all(r)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorPrune.reset_all(r)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorIndex.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorStop.reset_all(r)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
RedisConnectorPermissionSync.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
RedisConnectorExternalGroupSync.reset_all(r)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
@@ -191,36 +229,52 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
lock = worker.primary_worker_lock
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
r = get_redis_client(tenant_id=None)
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
worker.primary_worker_lock = lock
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception:
task_logger.exception("Periodic task failed.")
@@ -239,8 +293,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",

View File

@@ -1,10 +1,568 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the string `"1"`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
object_id = parts[1]
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:

View File

@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -17,7 +18,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -40,14 +41,14 @@ def _get_deletion_status(
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)
@@ -81,7 +82,7 @@ def extract_ids_from_runnable_connector(
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.

View File

@@ -1,60 +0,0 @@
from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": "check_for_external_group_sync",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule

View File

@@ -10,6 +10,13 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
@@ -18,8 +25,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -57,7 +62,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
@@ -66,10 +71,10 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -101,10 +106,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
# don't generate sync tasks if tasks are still pending
if redis_connector.delete.fenced:
if r.exists(rcd.fence_key):
return None
# we need to load the state of the object inside the fence
@@ -118,55 +123,47 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
redis_connector.delete.set_fence(fence_payload)
r.set(rcd.fence_key, fence_value.model_dump_json())
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.delete.generate_tasks(
app, db_session, lock_beat
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
redis_connector.delete.set_fence(None)
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
@@ -181,7 +178,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
return tasks_generated

View File

@@ -1,321 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip doc permissions sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
return True
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name="check_for_doc_permissions_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
if redis_connector.permissions.fenced:
return None
if redis_connector.delete.fenced:
return None
if redis_connector.prune.fenced:
return None
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
)
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name="connector_permission_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
self.app, lock, document_external_accesses, source_type
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)
raise e
finally:
if lock.owned():
lock.release()
@shared_task(
name="update_external_document_permissions_task",
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
batch_add_non_web_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
return False

View File

@@ -1,265 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
return False
# If the last sync is None, it has never been run so we run the sync
last_ext_group_sync = cc_pair.last_time_external_group_sync
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name="check_for_external_group_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
return None
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name="connector_external_group_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type}")
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
if lock.owned():
lock.release()

View File

@@ -2,9 +2,10 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,6 +14,12 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
@@ -43,15 +50,12 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -101,22 +105,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
return None
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
if old_search_settings:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
# only warm up if search settings were changed
warm_up_bi_encoder(
embedding_model=embedding_model,
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -125,7 +126,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
@@ -138,10 +138,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
)
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -175,9 +175,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
if attempt_id:
task_logger.info(
f"Indexing queued: index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
@@ -306,15 +304,15 @@ def try_creating_indexing_task(
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -322,17 +320,19 @@ def try_creating_indexing_task(
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_fence(payload)
r.set(rci.fence_key, fence_value.model_dump_json())
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
@@ -345,8 +345,6 @@ def try_creating_indexing_task(
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -363,12 +361,11 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
redis_connector_index.set_fence(payload)
r.delete(rci.fence_key)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "
@@ -391,12 +388,7 @@ def connector_indexing_proxy_task(
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
client = SimpleJobClient()
job = client.submit(
@@ -410,56 +402,29 @@ def connector_indexing_proxy_task(
)
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
)
if job.status == "error":
logger.error(job.exception())
job.release()
break
job.release()
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@@ -481,97 +446,78 @@ def connector_indexing_task(
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={rcd.fence_key}"
)
if redis_connector.stop.fenced:
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
f"connector_indexing_task - Waiting for fence: fence={redis_connector_index.fence_key}"
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
)
logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break
lock = r.lock(
redis_connector_index.generator_lock_key,
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -580,7 +526,6 @@ def connector_indexing_task(
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -600,52 +545,43 @@ def connector_indexing_task(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
# define a callback class
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return n_final_progress

View File

@@ -11,6 +11,9 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
@@ -30,43 +33,12 @@ from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import pruning_ctx
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due."""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
@@ -98,7 +70,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue
if not _is_pruning_due(cc_pair):
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
@@ -119,6 +91,47 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -133,11 +146,8 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
@@ -154,16 +164,15 @@ def try_creating_prune_generator_task(
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if redis_connector.prune.fenced:
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
@@ -171,10 +180,10 @@ def try_creating_prune_generator_task(
return None
# add a long running generator task to the queue
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
@@ -190,7 +199,7 @@ def try_creating_prune_generator_task(
)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -220,17 +229,12 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
@@ -263,11 +267,10 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -294,29 +297,31 @@ def connector_pruning_generator_task(
f"doc_source={cc_pair.connector.source}"
)
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -59,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
task_logger.info(f"tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -141,9 +141,7 @@ def document_by_cc_pair_cleanup_task(
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
@@ -173,8 +171,8 @@ def document_by_cc_pair_cleanup_task(
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id):

View File

@@ -19,6 +19,18 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -27,7 +39,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -56,18 +67,7 @@ from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@@ -192,7 +192,7 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
@@ -228,10 +228,10 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(tenant_id, document_set_id)
rds = RedisDocumentSet(document_set_id)
# don't generate document set sync tasks if tasks are still pending
if rds.fenced:
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
@@ -269,7 +269,7 @@ def try_generate_document_set_sync_tasks(
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
r.set(rds.fence_key, tasks_generated)
return tasks_generated
@@ -283,9 +283,10 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(tenant_id, usergroup_id)
if rug.fenced:
# don't generate sync tasks if tasks are still pending
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
@@ -325,7 +326,7 @@ def try_generate_user_group_sync_tasks(
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
r.set(rug.fence_key, tasks_generated)
return tasks_generated
@@ -351,7 +352,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -361,12 +362,16 @@ def monitor_document_set_taskset(
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(tenant_id, document_set_id)
if not rds.fenced:
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
initial_count = rds.payload
if initial_count is None:
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
@@ -394,38 +399,48 @@ def monitor_document_set_taskset(
f"Successfully synced document set: document_set={document_set_id}"
)
rds.reset()
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
return
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if remaining > 0:
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
@@ -509,15 +524,15 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
@@ -526,78 +541,46 @@ def monitor_ccpair_pruning_taskset(
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.prune.fenced:
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
initial = redis_connector.prune.generator_complete
if initial is None:
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
remaining = redis_connector.prune.get_remaining()
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if remaining > 0:
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
@@ -612,37 +595,53 @@ def monitor_ccpair_indexing_taskset(
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
if payload.index_attempt_id is None or payload.celery_task_id is None:
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
status_int = redis_connector_index.get_completion()
if status_int is None:
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
@@ -653,18 +652,30 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus(status_int)
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
@@ -673,7 +684,11 @@ def monitor_ccpair_indexing_taskset(
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
redis_connector_index.reset()
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@@ -685,7 +700,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
Returns True if the task actually did work, False
"""
r = get_redis_client(tenant_id=tenant_id)
@@ -714,17 +729,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"pruning={n_pruning}"
)
# do some cleanup before clearing fences
@@ -738,37 +749,29 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
@@ -779,25 +782,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
@@ -869,9 +866,7 @@ def vespa_metadata_sync_task(
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)

View File

@@ -11,8 +11,7 @@ from typing import Any
from typing import Literal
from typing import Optional
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -29,26 +28,16 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
"""Ensure the parent proc's database connections are not touched
in the new connection pool
Based on SQLAlchemy's recommendations to handle multiprocessing:
Based on the recommended approach in the SQLAlchemy docs found:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
get_sqlalchemy_engine().dispose(close=False)
return func(*args, **kwargs)

View File

@@ -33,8 +33,8 @@ from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -118,13 +118,7 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
@@ -337,7 +331,7 @@ def _run_indexing(
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
index_attempt,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
@@ -372,7 +366,7 @@ def _run_indexing(
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt.id,
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
@@ -427,7 +421,7 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -10,7 +10,7 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
class FileChatDisplay(BaseModel):
class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
@@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo

View File

@@ -42,14 +42,18 @@ personas:
display_priority: 1
is_visible: true
starter_messages:
- name: "Give me an overview of what's here"
message: "Sample some documents and tell me what you find."
- name: "Use AI to solve a work related problem"
message: "Ask me what problem I would like to solve, then search the knowledge base to help me find a solution."
- name: "Find updates on a topic of interest"
message: "Once I provide a topic, retrieve related documents and tell me when there was last activity on the topic if available."
- name: "Surface contradictions"
message: "Have me choose a subject. Once I have provided it, check against the knowledge base and point out any inconsistencies. For all your following responses, focus on identifying contradictions."
- name: "General Information"
description: "Ask about available information"
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
- name: "Specific Topic Search"
description: "Search for specific information"
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
- name: "Recent Updates"
description: "Inquire about latest additions"
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
- name: "Cross-referencing Information"
description: "Connect information from different sources"
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
- id: 1
name: "General"
@@ -67,14 +71,18 @@ personas:
display_priority: 0
is_visible: true
starter_messages:
- name: "Summarize a document"
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
- name: "Help me with coding"
message: 'Write me a "Hello World" script in 5 random languages to show off the functionality.'
- name: "Draft a professional email"
message: "Help me craft a professional email. Let's establish the context and the anticipated outcomes of the email before proposing a draft."
- name: "Learn something new"
message: "What is the difference between a Gantt chart, a Burndown chart and a Kanban board?"
- name: "Open Discussion"
description: "Start an open-ended conversation"
message: "Hi! Can you help me write a professional email?"
- name: "Problem Solving"
description: "Get help with a challenge"
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
- name: "Learn Something New"
description: "Explore a new topic"
message: "Hi! Could you explain what project management is in simple terms?"
- name: "Creative Brainstorming"
description: "Generate creative ideas"
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
- id: 2
name: "Paraphrase"
@@ -93,12 +101,16 @@ personas:
is_visible: false
starter_messages:
- name: "Document Search"
description: "Find exact information"
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
- name: "Process Verification"
description: "Find exact quotes"
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
- name: "Technical Documentation"
description: "Search technical details"
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
- name: "Policy Reference"
description: "Check official policies"
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
- id: 3
@@ -118,11 +130,15 @@ personas:
display_priority: 3
is_visible: true
starter_messages:
- name: "Create visuals for a presentation"
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
- name: "Find inspiration for a marketing campaign"
message: "Generate an image of two happy individuals sipping on a soda drink in a glass bottle."
- name: "Visualize a product design"
message: "I want to add a search bar to my Iphone app. Generate me generic examples of how other apps implement this."
- name: "Generate a humorous image response"
message: "My teammate just made a silly mistake and I want to respond with a facepalm. Can you generate me one?"
- name: "Landscape"
description: "Generate a landscape image"
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
- name: "Character"
description: "Generate a character image"
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
- name: "Abstract"
description: "Create an abstract image"
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
- name: "Urban Scene"
description: "Generate an urban landscape"
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."

View File

@@ -11,18 +11,23 @@ from danswer.chat.models import AllCitations
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -35,6 +40,7 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -54,13 +60,14 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -69,48 +76,36 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from danswer.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -261,11 +256,10 @@ ChatPacket = (
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| FileChatDisplay
| ImageGenerationDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| StreamStopInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -281,6 +275,7 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -292,9 +287,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -416,20 +408,12 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -500,19 +484,13 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -527,13 +505,7 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -545,7 +517,6 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -561,53 +532,148 @@ def stream_chat_message_objects(
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(
llm
@@ -675,6 +741,7 @@ def stream_chat_message_objects(
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
@@ -692,7 +759,7 @@ def stream_chat_message_objects(
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield FileChatDisplay(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
@@ -706,32 +773,11 @@ def stream_chat_message_objects(
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
@@ -761,7 +807,6 @@ def stream_chat_message_objects(
# Post-LLM answer processing
try:
logger.debug("Post-LLM answer processing")
message_specific_citations: MessageSpecificCitations | None = None
if reference_db_search_docs:
message_specific_citations = _translate_citations(
@@ -777,6 +822,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -784,21 +830,21 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else None
else []
),
)
@@ -822,6 +868,7 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -831,6 +878,7 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -9,19 +9,19 @@ prompts:
system: >
You are a question answering system that is constantly learning and improving.
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
You always clearly communicate ANY UNCERTAINTY in your answer.
# Task Prompt (as shown in UI)
task: >
Answer my query based on the documents provided.
The documents may not all be relevant, ignore any documents that are not directly relevant
to the most recent user query.
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
@@ -30,21 +30,21 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images from user descriptions!"
description: "Generates images based on user prompts!"
system: >
You are an AI image generation assistant. Your role is to create high-quality images based on user descriptions.
For appropriate requests, you will generate an image that matches the user's requirements.
For inappropriate or unsafe requests, you will politely decline and explain why the request cannot be fulfilled.
You aim to be helpful while maintaining appropriate content standards.
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Based on the user's description, create a high-quality image that accurately reflects their request.
Pay close attention to the specified details, styles, and desired elements.
If the request is not appropriate or cannot be fulfilled, explain why and suggest alternatives.
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
@@ -64,13 +64,14 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Summarize"
description: "Summarize relevant information from retrieved context!"
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -83,6 +84,7 @@ prompts:
datetime_aware: true
include_citations: true
- name: "Paraphrase"
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
@@ -90,10 +92,10 @@ prompts:
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
simply tell the user that there are no documents to reference.
You NEVER generate new text or phrases outside of the citation.
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
task: >

View File

@@ -163,17 +163,6 @@ try:
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
# Experimental setting to control idle transactions
POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT = 0 # milliseconds
try:
POSTGRES_IDLE_SESSIONS_TIMEOUT = int(
os.environ.get(
"POSTGRES_IDLE_SESSIONS_TIMEOUT", POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
)
)
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
@@ -262,6 +251,9 @@ ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# for some connectors
ENABLE_EXPENSIVE_EXPERT_CALLS = False
GOOGLE_DRIVE_INCLUDE_SHARED = False
GOOGLE_DRIVE_FOLLOW_SHORTCUTS = False
GOOGLE_DRIVE_ONLY_ORG_PUBLIC = False
# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
@@ -489,21 +481,3 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")

View File

@@ -80,10 +80,6 @@ CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -129,8 +125,6 @@ class DocumentSource(str, Enum):
OCI_STORAGE = "oci_storage"
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -213,17 +207,9 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
@@ -233,24 +219,11 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0

View File

@@ -119,14 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# if specified, will merge the specified JSON with the existing body of the
# request before sending it to the LLM
LITELLM_EXTRA_BODY: dict | None = None
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
if _LITELLM_EXTRA_BODY_RAW:
try:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass

View File

@@ -17,7 +17,6 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
@@ -92,13 +91,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
cql_page_query += f" and id='{page_id}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.cql_label_filter = ""
self.cql_time_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
comma_separated_labels = ",".join(labels_to_skip)
self.cql_label_filter = f"&label not in ({comma_separated_labels})"
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
@@ -127,8 +125,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
confluence_client=self.confluence_client, confluence_object=comment
)
return comment_string
@@ -146,7 +143,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
self.wiki_base, confluence_object["_links"]["webui"]
)
object_text = None
@@ -250,11 +247,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -278,9 +271,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
self.wiki_base, page["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
@@ -295,9 +286,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)

View File

@@ -100,39 +100,6 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ri:page"):
# Wrap this in a try-except because there are some pages that might not exist
try:
page_title = html_page_reference.attrs["ri:content-title"]
if not page_title:
continue
page_query = f"type=page and title='{page_title}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page_batch in confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page_batch[0]
break
except Exception:
logger.warning(
f"Error getting page contents for object {confluence_object}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client, page_contents
)
html_page_reference.replaceWith(text_from_page)
return format_document_soup(soup)
@@ -186,9 +153,7 @@ def attachment_to_content(
return extracted_text
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@@ -199,8 +164,6 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"

View File

@@ -23,16 +23,7 @@ def datetime_to_utc(dt: datetime) -> datetime:
def time_str_to_utc(datetime_str: str) -> datetime:
try:
dt = parse(datetime_str)
except ValueError:
# Handle malformed timezone by attempting to fix common format issues
if "0000" in datetime_str:
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
else:
raise
dt = parse(datetime_str)
return datetime_to_utc(dt)

View File

@@ -16,8 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
from danswer.connectors.gmail.connector import GmailConnector
@@ -101,8 +99,6 @@ def identify_connector_class(
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -123,13 +123,9 @@ def _process_file(
"filename",
"file_display_name",
"title",
"connector_type",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
@@ -149,7 +145,7 @@ def _process_file(
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
source=DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,

View File

@@ -1,182 +0,0 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
_FIREFLIES_API_QUERY = """
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
participants
date
transcript_url
sentences {
text
speaker_name
}
}
}
"""
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
meeting_link = transcript["transcript_url"]
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
meeting_title = transcript["title"] or "No Title"
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
secondary_owners=meeting_participants_email_list,
)
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str]) -> None:
api_key = credentials.get("fireflies_api_key")
if not isinstance(api_key, str):
raise ConnectorMissingCredentialError(
"The Fireflies API key must be a string"
)
self.api_key = api_key
return None
def _fetch_transcripts(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Iterator[List[dict]]:
if self.api_key is None:
raise ConnectorMissingCredentialError("Missing API key")
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
skip = 0
variables: dict[str, int | str] = {
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
}
if start_datetime:
variables["fromDate"] = start_datetime
if end_datetime:
variables["toDate"] = end_datetime
while True:
variables["skip"] = skip
response = requests.post(
_FIREFLIES_API_URL,
headers=headers,
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
)
response.raise_for_status()
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
"transcripts", []
)
yield parsed_transcripts
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
break
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
def _process_transcripts(
self, start: str | None = None, end: str | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for transcript_batch in self._fetch_transcripts(start, end):
for transcript in transcript_batch:
if doc := _create_doc_from_transcript(transcript):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
yield from self._process_transcripts(start_datetime, end_datetime)

View File

@@ -1,239 +0,0 @@
import json
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FRESHDESK_ID_PREFIX = "FRESHDESK_"
_TICKET_FIELDS_TO_INCLUDE = {
"fr_escalated",
"spam",
"priority",
"source",
"status",
"type",
"is_escalated",
"tags",
"nr_due_by",
"nr_escalated",
"cc_emails",
"fwd_emails",
"reply_cc_emails",
"ticket_cc_emails",
"support_email",
"to_emails",
}
_SOURCE_NUMBER_TYPE_MAP: dict[int, str] = {
1: "Email",
2: "Portal",
3: "Phone",
7: "Chat",
9: "Feedback Widget",
10: "Outbound Email",
}
_PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = {
1: "low",
2: "medium",
3: "high",
4: "urgent",
}
_STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
2: "open",
3: "pending",
4: "resolved",
5: "closed",
}
def _create_metadata_from_ticket(ticket: dict) -> dict:
metadata: dict[str, str | list[str]] = {}
# Combine all emails into a list so there are no repeated emails
email_data: set[str] = set()
for key, value in ticket.items():
# Skip fields that aren't useful for embedding
if key not in _TICKET_FIELDS_TO_INCLUDE:
continue
# Skip empty fields
if not value or value == "[]":
continue
# Convert strings or lists to strings
stringified_value: str | list[str]
if isinstance(value, list):
stringified_value = [str(item) for item in value]
else:
stringified_value = str(value)
if "email" in key:
if isinstance(stringified_value, list):
email_data.update(stringified_value)
else:
email_data.add(stringified_value)
else:
metadata[key] = stringified_value
if email_data:
metadata["emails"] = list(email_data)
# Convert source numbers to human-parsable string
if source_number := ticket.get("source"):
metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get(
source_number, "Unknown Source Type"
)
# Convert priority numbers to human-parsable string
if priority_number := ticket.get("priority"):
metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get(
priority_number, "Unknown Priority"
)
# Convert status to human-parsable string
if status_number := ticket.get("status"):
metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get(
status_number, "Unknown Status"
)
due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00"))
metadata["overdue"] = str(datetime.now(timezone.utc) > due_by)
return metadata
def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
# Use the ticket description as the text
text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}"
metadata = _create_metadata_from_ticket(ticket)
# This is also used in the ID because it is more unique than the just the ticket ID
link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}"
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
Section(
link=link,
text=text,
)
],
source=DocumentSource.FRESHDESK,
semantic_identifier=ticket["subject"],
metadata=metadata,
doc_updated_at=datetime.fromisoformat(
ticket["updated_at"].replace("Z", "+00:00")
),
)
class FreshdeskConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str | int]) -> None:
api_key = credentials.get("freshdesk_api_key")
domain = credentials.get("freshdesk_domain")
password = credentials.get("freshdesk_password")
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
raise ConnectorMissingCredentialError(
"All Freshdesk credentials must be strings"
)
self.api_key = str(api_key)
self.domain = str(domain)
self.password = str(password)
def _fetch_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> Iterator[List[dict]]:
"""
'end' is not currently used, so we may double fetch tickets created after the indexing
starts but before the actual call is made.
To use 'end' would require us to use the search endpoint but it has limitations,
namely having to fetch all IDs and then individually fetch each ticket because there is no
'include' field available for this endpoint:
https://developers.freshdesk.com/api/#filter_tickets
"""
if self.api_key is None or self.domain is None or self.password is None:
raise ConnectorMissingCredentialError("freshdesk")
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
params: dict[str, int | str] = {
"include": "description",
"per_page": 50,
"page": 1,
}
if start:
params["updated_since"] = start.isoformat()
while True:
response = requests.get(
base_url, auth=(self.api_key, self.password), params=params
)
response.raise_for_status()
if response.status_code == 204:
break
tickets = json.loads(response.content)
logger.info(
f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})"
)
yield tickets
if len(tickets) < int(params["per_page"]):
break
params["page"] = int(params["page"]) + 1
def _process_tickets(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for ticket_batch in self._fetch_tickets(start, end):
for ticket in ticket_batch:
doc_batch.append(_create_doc_from_ticket(ticket, self.domain))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_tickets()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_tickets(start_datetime, end_datetime)

View File

@@ -1,361 +1,221 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_authorized_user,
)
from danswer.connectors.gmail.connector_auth import (
get_gmail_creds_for_service_account,
)
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# This is for the initial list call to get the thread ids
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"
# These are the fields to retrieve using the ID from the initial list call
PARTS_FIELDS = "parts(body(data), mimeType)"
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"
EMAIL_FIELDS = [
"cc",
"bcc",
"from",
"to",
]
add_retries = retry_builder(tries=50, max_delay=30)
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _clean_email_and_extract_name(email: str) -> tuple[str, str | None]:
email = email.strip()
if "<" in email and ">" in email:
# Handle format: "Display Name <email@domain.com>"
display_name = email[: email.find("<")].strip()
email_address = email[email.find("<") + 1 : email.find(">")].strip()
return email_address, display_name if display_name else None
else:
# Handle plain email address
return email.strip(), None
def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
owners = []
for email, names in emails.items():
if names:
name_parts = names.split(" ")
first_name = " ".join(name_parts[:-1])
last_name = name_parts[-1]
else:
first_name = None
last_name = None
owners.append(
BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
)
return owners
def _get_message_body(payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
message_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain" and body:
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
message_body += text
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
headers = payload.get("headers", [])
metadata: dict[str, Any] = {}
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in EMAIL_FIELDS:
metadata[name] = value
if name == "subject":
metadata["subject"] = value
if name == "date":
metadata["updated_at"] = value
if labels := message.get("labelIds"):
metadata["labels"] = labels
message_data = ""
for name, value in metadata.items():
# updated at isnt super useful for the llm
if name != "updated_at":
message_data += f"{name}: {value}\n"
message_body_text: str = _get_message_body(payload)
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
all_messages = full_thread.get("messages", [])
if not all_messages:
return None
sections = []
semantic_identifier = ""
updated_at = None
from_emails: dict[str, str | None] = {}
other_emails: dict[str, str | None] = {}
for message in all_messages:
section, message_metadata = message_to_section(message)
sections.append(section)
for name, value in message_metadata.items():
if name in EMAIL_FIELDS:
email, display_name = _clean_email_and_extract_name(value)
if name == "from":
from_emails[email] = (
display_name if not from_emails.get(email) else None
)
else:
other_emails[email] = (
display_name if not other_emails.get(email) else None
)
# If we haven't set the semantic identifier yet, set it to the subject of the first message
if not semantic_identifier:
semantic_identifier = message_metadata.get("subject", "")
if message_metadata.get("updated_at"):
updated_at = message_metadata.get("updated_at")
updated_at_datetime = None
if updated_at:
updated_at_datetime = time_str_to_utc(updated_at)
id = full_thread.get("id")
if not id:
raise ValueError("Thread ID is required")
primary_owners = _get_owners_from_emails(from_emails)
secondary_owners = _get_owners_from_emails(other_emails)
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,
secondary_owners=secondary_owners,
doc_updated_at=updated_at_datetime,
# Not adding emails to metadata because it's already in the sections
metadata={},
)
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
class GmailConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._primary_admin_email: str | None = None
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorugh
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_gmail_creds_for_authorized_user(
token_json_str=access_token_json_str
)
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GMAIL,
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_gmail_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Gmail - unknown credential structure."
)
self.creds = creds
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _get_email_body(self, payload: dict[str, Any]) -> str:
parts = payload.get("parts", [])
email_body = ""
for part in parts:
mime_type = part.get("mimeType")
body = part.get("body")
if mime_type == "text/plain":
data = body.get("data", "")
text = urlsafe_b64decode(data).decode()
email_body += text
return email_body
def _fetch_threads(
def _email_to_document(self, full_email: Dict[str, Any]) -> Document:
email_id = full_email["id"]
payload = full_email["payload"]
headers = payload.get("headers")
labels = full_email.get("labelIds", [])
metadata = {}
if headers:
for header in headers:
name = header.get("name").lower()
value = header.get("value")
if name in ["from", "to", "subject", "date", "cc", "bcc"]:
metadata[name] = value
email_data = ""
for name, value in metadata.items():
email_data += f"{name}: {value}\n"
metadata["labels"] = labels
logger.debug(f"{email_data}")
email_body_text: str = self._get_email_body(payload)
date_str = metadata.get("date")
email_updated_at = time_str_to_utc(date_str) if date_str else None
link = f"https://mail.google.com/mail/u/0/#inbox/{email_id}"
return Document(
id=email_id,
sections=[Section(link=link, text=email_data + email_body_text)],
source=DocumentSource.GMAIL,
title=metadata.get("subject"),
semantic_identifier=metadata.get("subject", "Untitled Email"),
doc_updated_at=email_updated_at,
metadata=metadata,
)
@staticmethod
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
query = ""
if time_range_start is not None and time_range_start != 0:
query += f"after:{int(time_range_start)}"
if time_range_end is not None and time_range_end != 0:
query += f" before:{int(time_range_end)}"
query = query.strip()
if len(query) == 0:
return None
return query
def _fetch_mails_from_gmail(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
full_threads = execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
if self.creds is None:
raise PermissionError("Not logged into Gmail")
page_token = ""
query = GmailConnector._build_time_range_query(time_range_start, time_range_end)
service = discovery.build("gmail", "v1", credentials=self.creds)
while page_token is not None:
result = (
service.users()
.messages()
.list(
userId="me",
pageToken=page_token,
q=query,
maxResults=self.batch_size,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread)
if doc is None:
continue
.execute()
)
page_token = result.get("nextPageToken")
messages = result.get("messages", [])
doc_batch = []
for message in messages:
message_id = message["id"]
msg = (
service.users()
.messages()
.get(userId="me", id=message_id, format="full")
.execute()
)
doc = self._email_to_document(msg)
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def _fetch_slim_threads(
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
perm_sync_data={"user_email": user_email},
)
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
if len(doc_batch) > 0:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._fetch_threads()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_mails_from_gmail()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._fetch_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._fetch_slim_threads(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_mails_from_gmail(start, end)
if __name__ == "__main__":
pass
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GMAIL_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GmailConnector()
connector.load_credentials(
json.loads(credentials_dict[DB_CREDENTIALS_DICT_TOKEN_KEY])
)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break

View File

@@ -0,0 +1,197 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_gmail_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
def get_gmail_creds_for_authorized_user(
token_json_str: str,
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Gmail tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh gmail access token due to: {e}")
return None
return None
def get_gmail_creds_for_service_account(
service_account_key_json_str: str,
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
)
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_gmail_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_gmail_service_account_key()
credential_dict = {
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_gmail_cred() -> None:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -0,0 +1,4 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

View File

@@ -1,437 +1,556 @@
from collections.abc import Callable
import io
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient import discovery # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import GOOGLE_DRIVE_FOLLOW_SHORTCUTS
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_auth import get_google_creds
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_google_docs_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from danswer.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from danswer.connectors.google_utils.shared_constants import SCOPE_DOC_URL
from danswer.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from danswer.connectors.google_utils.shared_constants import USER_FIELDS
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
# All file retrievals could be batched and made at once
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
if not string:
return []
return [s.strip() for s in string.split(",") if s.strip()]
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
GoogleDriveFileType = dict[str, Any]
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _convert_single_file(
creds: Any, primary_admin_email: str, file: dict[str, Any]
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
docs_service = get_google_docs_service(creds, user_email=user_email)
return convert_drive_item_to_document(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
def _run_drive_file_query(
service: discovery.Resource,
query: str,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = add_retries(
lambda: (
service.files()
.list(
corpora="allDrives"
if include_shared
else "user", # needed to search through shared drives
pageSize=batch_size,
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
fields=(
"nextPageToken, files(mimeType, id, name, permissions, "
"modifiedTime, webViewLink, shortcutDetails)"
),
pageToken=next_page_token,
q=query,
)
.execute()
)
)()
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, permissions, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
)
if continue_on_failure:
continue
raise
else:
yield file
def _get_folder_id(
service: discovery.Resource,
parent_id: str,
folder_name: str,
include_shared: bool,
follow_shortcuts: bool,
) -> str | None:
"""
Get the ID of a folder given its name and the ID of its parent folder.
"""
query = f"'{parent_id}' in parents and name='{folder_name}' and "
if follow_shortcuts:
query += f"(mimeType='{DRIVE_FOLDER_TYPE}' or mimeType='{DRIVE_SHORTCUT_TYPE}')"
else:
query += f"mimeType='{DRIVE_FOLDER_TYPE}'"
# TODO: support specifying folder path in shared drive rather than just `My Drive`
results = add_retries(
lambda: (
service.files()
.list(
q=query,
spaces="drive",
fields="nextPageToken, files(id, name, shortcutDetails)",
supportsAllDrives=include_shared,
includeItemsFromAllDrives=include_shared,
)
.execute()
)
)()
items = results.get("files", [])
folder_id = None
if items:
if follow_shortcuts and "shortcutDetails" in items[0]:
folder_id = items[0]["shortcutDetails"]["targetId"]
else:
folder_id = items[0]["id"]
return folder_id
def _get_folders(
service: discovery.Resource,
continue_on_failure: bool,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if follow_shortcuts:
query = "(" + query + f" or mimeType = '{DRIVE_SHORTCUT_TYPE}'" + ") "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
for file in _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
):
# Need to check this since file may have been a target of a shortcut
# and not necessarily a folder
if file["mimeType"] == DRIVE_FOLDER_TYPE:
yield file
else:
pass
def _get_files(
service: discovery.Resource,
continue_on_failure: bool,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
def _process_files_batch(
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
for doc in executor.map(convert_func, files):
if doc:
doc_batch.append(doc)
if len(doc_batch) >= batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
return files
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# OLD PARAMETERS
folder_paths: list[str] | None = None,
include_shared: bool | None = None,
follow_shortcuts: bool | None = None,
only_org_public: bool | None = None,
continue_on_failure: bool | None = None,
) -> None:
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
def get_all_files_batched(
service: discovery.Resource,
continue_on_failure: bool,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
batch_size: int = INDEX_BATCH_SIZE,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID.
# Only applies if folder_id is specified.
traverse_subfolders: bool = True,
folder_ids_traversed: list[str] | None = None,
) -> Iterator[list[GoogleDriveFileType]]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
found_files = _get_files(
service=service,
continue_on_failure=continue_on_failure,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if (
not include_shared_drives
and not include_my_drives
and not shared_folder_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.batch_size = batch_size
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
if traverse_subfolders and folder_id is not None:
folder_ids_traversed = folder_ids_traversed or []
subfolders = _get_folders(
service=service,
folder_id=folder_id,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
)
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list))
self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
@property
def primary_admin_email(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email
@property
def google_domain(self) -> str:
if self._primary_admin_email is None:
raise RuntimeError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
return self._primary_admin_email.split("@")[-1]
@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
if self._creds is None:
raise RuntimeError(
"Creds missing, "
"should not call this property "
"before calling load_credentials"
)
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
source=DocumentSource.GOOGLE_DRIVE,
)
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self) -> list[str]:
# Start with primary admin email
user_emails = [self.primary_admin_email]
# Only fetch additional users if using service account
if isinstance(self.creds, OAuthCredentials):
return user_emails
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
# Get admins first since they're more likely to have access to most files
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
if email not in user_emails:
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can
# access your my drive as a non admin user in an org still
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
if not all_drive_ids:
logger.warning(
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
)
all_drive_ids = set(self._requested_shared_drive_ids)
return all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - no specific emails were requested
# - the current user's email is in the requested emails
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
start=start,
end=end,
)
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)
# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
for subfolder in subfolders:
if subfolder["id"] not in folder_ids_traversed:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
folder_ids_traversed.append(subfolder["id"])
yield from get_all_files_batched(
service=service,
continue_on_failure=continue_on_failure,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
batch_size=batch_size,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=subfolder["id"],
traverse_subfolders=traverse_subfolders,
folder_ids_traversed=folder_ids_traversed,
)
else:
filtered_drive_ids = all_drive_ids
logger.debug(
"Skipping subfolder since already traversed: " + subfolder["name"]
)
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval,
email,
is_slim,
filtered_drive_ids,
filtered_folder_ids,
start,
end,
): email
for email in all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
remaining_folders = (
filtered_drive_ids | filtered_folder_ids
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
)
def _extract_docs_from_google_drive(
if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT
class GoogleDriveConnector(LoadConnector, PollConnector):
def __init__(
self,
# optional list of folder paths e.g. "[My Folder/My Subfolder]"
# if specified, will only index files in these folders
folder_paths: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
follow_shortcuts: bool = GOOGLE_DRIVE_FOLLOW_SHORTCUTS,
only_org_public: bool = GOOGLE_DRIVE_ONLY_ORG_PUBLIC,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.folder_paths = folder_paths or []
self.batch_size = batch_size
self.include_shared = include_shared
self.follow_shortcuts = follow_shortcuts
self.only_org_public = only_org_public
self.continue_on_failure = continue_on_failure
self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
@staticmethod
def _process_folder_paths(
service: discovery.Resource,
folder_paths: list[str],
include_shared: bool,
follow_shortcuts: bool,
) -> list[str]:
"""['Folder/Sub Folder'] -> ['<FOLDER_ID>']"""
folder_ids: list[str] = []
for path in folder_paths:
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = _get_folder_id(
service=service,
parent_id=parent_id,
folder_name=folder_name,
include_shared=include_shared,
follow_shortcuts=follow_shortcuts,
)
if found_parent_id is None:
raise ValueError(
(
f"Folder '{folder_name}' in path '{path}' "
"not found in Google Drive"
)
)
parent_id = found_parent_id
folder_ids.append(parent_id)
return folder_ids
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
def _fetch_docs_from_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
# Create a larger process pool for file conversion
convert_func = partial(
_convert_single_file, self.creds, self.primary_admin_email
if self.creds is None:
raise PermissionError("Not logged into Google Drive")
service = discovery.build("drive", "v3", credentials=self.creds)
folder_ids: Sequence[str | None] = self._process_folder_paths(
service, self.folder_paths, self.include_shared, self.follow_shortcuts
)
if not folder_ids:
folder_ids = [None]
# Process files in larger batches
LARGE_BATCH_SIZE = self.batch_size * 4
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
file_batches = chain(
*[
get_all_files_batched(
service=service,
continue_on_failure=self.continue_on_failure,
include_shared=self.include_shared,
follow_shortcuts=self.follow_shortcuts,
batch_size=self.batch_size,
time_range_start=start,
time_range_end=end,
folder_id=folder_id,
traverse_subfolders=True,
)
files_to_process = []
for folder_id in folder_ids
]
)
for files_batch in file_batches:
doc_batch = []
for file in files_batch:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
continue
# Process any remaining files
if files_to_process:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
if self.only_org_public:
if "permissions" not in file:
continue
if not any(
permission["type"] == "domain"
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception(
"Ran into exception when pulling a file from Google Drive"
)
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive()
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
yield from self._fetch_docs_from_drive()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
# need to subtract 10 minutes from start time to account for modifiedTime
# propogation if a document is modified, it takes some time for the API to
# reflect these changes if we do not have an offset, then we may "miss" the
# update when polling
yield from self._fetch_docs_from_drive(start, end)
def _extract_slim_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
is_slim=True,
start=start,
end=end,
):
if doc := build_slim_document(file):
slim_batch.append(doc)
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
yield slim_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
if __name__ == "__main__":
import json
import os
service_account_json_path = os.environ.get("GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH")
if not service_account_json_path:
raise ValueError(
"Please set GOOGLE_SERVICE_ACCOUNT_KEY_JSON_PATH environment variable"
)
with open(service_account_json_path) as f:
creds = json.load(f)
credentials_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: json.dumps(creds),
}
delegated_user = os.environ.get("GOOGLE_DRIVE_DELEGATED_USER")
if delegated_user:
credentials_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user
connector = GoogleDriveConnector(include_shared=True, follow_shortcuts=True)
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print(document_batch)
break

View File

@@ -0,0 +1,229 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
return None
return None
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if delegated_user_email:
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_kv_store().store(KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True)
def delete_google_app_cred() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -1,4 +1,7 @@
UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"
DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut"
DRIVE_FILE_TYPE = "application/vnd.google-apps.file"
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

View File

@@ -1,197 +0,0 @@
import io
from datetime import datetime
from datetime import timezone
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.section_extraction import get_document_sections
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.google_utils.resources import GoogleDriveService
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.file_processing.extract_file_text import docx_to_text
from danswer.file_processing.extract_file_text import pptx_to_text
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.logger import setup_logger
logger = setup_logger()
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
if mime_type == GDriveMimeType.WORD_DOC.value:
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
# Skip files that are folders
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
logger.info("Ignoring Drive Folder Filetype")
return None
sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
return Document(
id=file["webViewLink"],
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
# Skip files that are folders or shortcuts
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),
"owner_email": file.get("owners", [{}])[0].get("emailAddress"),
},
)

View File

@@ -1,222 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from googleapiclient.discovery import Resource # type: ignore
from danswer.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from danswer.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.utils.logger import setup_logger
logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
def _generate_time_range_filter(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
return time_range_filter
def _get_folders_in_parent(
service: Resource,
parent_id: str | None = None,
) -> Iterator[GoogleDriveFileType]:
# Follow shortcuts to folders
query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')"
query += " and trashed = false"
if parent_id:
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=FOLDER_FIELDS,
q=query,
):
yield file
def _get_files_in_parent(
service: Resource,
parent_id: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
yield file
def crawl_folders_for_files(
service: Resource,
parent_id: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
for subfolder in _get_folders_in_parent(
service=service,
parent_id=parent_id,
):
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
end=end,
)
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
driveId=drive_id,
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
def get_all_files_in_my_drive(
service: Any,
update_traversed_ids_func: Callable,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = "trashed = false and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
)
# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]

View File

@@ -1,18 +0,0 @@
from enum import Enum
from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"
POWERPOINT = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
PLAIN_TEXT = "text/plain"
MARKDOWN = "text/markdown"
GoogleDriveFileType = dict[str, Any]

View File

@@ -1,105 +0,0 @@
from typing import Any
from pydantic import BaseModel
from danswer.connectors.google_utils.resources import GoogleDocsService
from danswer.connectors.models import Section
class CurrentHeading(BaseModel):
id: str
text: str
def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str:
"""Builds a Google Doc link that jumps to a specific heading"""
# NOTE: doesn't support docs with multiple tabs atm, if we need that ask
# @Chris
return (
f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}"
)
def _extract_id_from_heading(paragraph: dict[str, Any]) -> str:
"""Extracts the id from a heading paragraph element"""
return paragraph["paragraphStyle"]["headingId"]
def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
"""Extracts the text content from a paragraph element"""
text_elements = []
for element in paragraph.get("elements", []):
if "textRun" in element:
text_elements.append(element["textRun"].get("content", ""))
return "".join(text_elements)
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
for element in content:
if "paragraph" not in element:
continue
paragraph = element["paragraph"]
# Check if this is a heading
if (
"paragraphStyle" in paragraph
and "namedStyleType" in paragraph["paragraphStyle"]
):
style = paragraph["paragraphStyle"]["namedStyleType"]
is_heading = style.startswith("HEADING_")
is_title = style.startswith("TITLE")
if is_heading or is_title:
# If we were building a previous section, add it to sections list
if current_heading is not None and current_section:
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
current_section = []
# Start new heading
heading_id = _extract_id_from_heading(paragraph)
heading_text = _extract_text_from_paragraph(paragraph)
current_heading = CurrentHeading(
id=heading_id,
text=heading_text,
)
continue
# Add content to current section
if current_heading is not None:
text = _extract_text_from_paragraph(paragraph)
if text.strip():
current_section.append(text)
# Don't forget to add the last section
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
)
return sections

View File

@@ -1,107 +0,0 @@
import json
from typing import cast
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_google_oauth_creds(
token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
scopes=GOOGLE_SCOPES[source],
)
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception:
logger.exception("Failed to refresh google drive access token due to:")
return None
return None
def get_google_creds(
credentials: dict[str, str],
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_oauth_creds(
token_json_str=access_token_json_str, source=source
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_account_key = json.loads(service_account_key_json_str)
service_creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=GOOGLE_SCOPES[source]
)
if not service_creds.valid or not service_creds.expired:
service_creds.refresh(Request())
if not service_creds.valid:
raise PermissionError(
f"Unable to access {source} - service account credentials are invalid."
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
f"Unable to access {source} - unknown credential structure."
)
return creds, new_creds_dict

View File

@@ -1,237 +0,0 @@
import json
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
from danswer.connectors.google_utils.shared_constants import (
ONYX_SCOPE_INSTRUCTIONS,
)
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.key_value_store.factory import get_kv_store
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
elif source == DocumentSource.GMAIL:
return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback"
else:
raise ValueError(f"Unsupported source: {source}")
def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
drive_service = get_drive_service(creds)
user_info = (
drive_service.about()
.get(
fields="user(emailAddress)",
)
.execute()
)
email = user_info.get("user", {}).get("emailAddress")
elif source == DocumentSource.GMAIL:
gmail_service = get_gmail_service(creds)
user_info = (
gmail_service.users()
.getProfile(
userId="me",
fields="emailAddress",
)
.execute()
)
email = user_info.get("emailAddress")
else:
raise ValueError(f"Unsupported source: {source}")
return email
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def update_credential_access_tokens(
auth_code: str,
credential_id: int,
user: User,
db_session: Session,
source: DocumentSource,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
token_json_str = creds.to_json()
# Get user email from Google API so we know who
# the primary admin is for this connector
try:
email = _get_current_oauth_user(creds, source)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
return None
return creds
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
credential_dict = {
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
}
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
)
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
redirect_uri=_build_frontend_google_drive_redirect(source),
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(
app_credentials: GoogleAppCredentials, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_google_app_cred(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_CRED_KEY)
else:
raise ValueError(f"Unsupported source: {source}")
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(
service_account_key: GoogleServiceAccountKey, source: DocumentSource
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.json(),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")
def delete_service_account_key(source: DocumentSource) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
elif source == DocumentSource.GMAIL:
get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -1,125 +0,0 @@
import re
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
try:
return request.execute()
except HttpError as error:
attempt += 1
if error.resp.status == 429:
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Extract 'Retry after' timestamp from error message
match = re.search(
r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)",
str(error),
)
if match:
retry_after_timestamp = match.group(1)
retry_after_dt = datetime.strptime(
retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"
).replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
sleep_time = max(
int((retry_after_dt - current_time).total_seconds()),
0,
)
else:
logger.error(
f"No Retry-After header or timestamp found in error message: {error}"
)
sleep_time = 60
sleep_time += 3 # Add a buffer to be safe
logger.info(
f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds."
)
time.sleep(sleep_time)
else:
raise
# If we've exhausted all attempts
raise Exception(f"Failed to execute request after {max_attempts} attempts")
def execute_paginated_retrieval(
retrieval_function: Callable,
list_key: str | None = None,
continue_on_404_or_403: bool = False,
**kwargs: Any,
) -> Iterator[GoogleDriveFileType]:
"""Execute a paginated retrieval from Google Drive API
Args:
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
except HttpError as e:
if e.resp.status >= 500:
results = add_retries(
lambda: retrieval_function(**request_kwargs).execute()
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
if list_key:
for item in results.get(list_key, []):
yield item
else:
yield results

View File

@@ -1,63 +0,0 @@
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
class GoogleDriveService(Resource):
pass
class GoogleDocsService(Resource):
pass
class AdminService(Resource):
pass
class GmailService(Resource):
pass
def _get_google_service(
service_name: str,
service_version: str,
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
elif isinstance(creds, OAuthCredentials):
service = build(service_name, service_version, credentials=creds)
return service
def get_google_docs_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDocsService:
return _get_google_service("docs", "v1", creds, user_email)
def get_drive_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService:
return _get_google_service("drive", "v3", creds, user_email)
def get_admin_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> AdminService:
return _get_google_service("admin", "directory_v1", creds, user_email)
def get_gmail_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GmailService:
return _get_google_service("gmail", "v1", creds, user_email)

View File

@@ -1,40 +0,0 @@
from danswer.configs.constants import DocumentSource
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_SCOPES = {
DocumentSource.GOOGLE_DRIVE: [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
}
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Danswer without updating the Google Auth scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)
# This is the maximum number of threads that can be retrieved at once
SLIM_BATCH_SIZE = 500

View File

@@ -56,11 +56,7 @@ class PollConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import builtins
import functools
import itertools
import tempfile
from typing import Any
from unittest import mock
from urllib.parse import urlparse
@@ -19,8 +18,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
@mock.patch.object(
builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import datetime
import itertools
import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
@@ -26,8 +25,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
pywikibot.config.base_dir = tempfile.TemporaryDirectory().name
def pywikibot_timestamp_to_utc_datetime(
timestamp: pywikibot.time.Timestamp,
@@ -124,6 +121,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [

View File

@@ -251,11 +251,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []

View File

@@ -391,11 +391,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -1,7 +1,10 @@
from collections.abc import Iterator
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects import Ticket # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
@@ -17,244 +20,43 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
class ZendeskCredentialsNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__(
"Zendesk Credentials are not set up, was load_credentials called?"
)
class ZendeskClient:
def __init__(self, subdomain: str, email: str, token: str):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
@retry_builder()
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
response.raise_for_status()
return response.json()
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
try:
while True:
data = client.make_request("guide/content_tags", params)
for tag in data.get("records", []):
content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
return content_tags
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
while True:
data = client.make_request("help_center/articles", params)
for article in data["articles"]:
yield article
if not data.get("meta", {}).get("has_more"):
break
params["page[after]"] = data["meta"]["after_cursor"]
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
while True:
data = client.make_request("incremental/tickets.json", params)
for ticket in data["tickets"]:
yield ticket
if not data.get("end_of_stream", False):
params["start_time"] = data["end_time"]
else:
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
def _article_to_document(
article: dict[str, Any],
content_tags: dict[str, str],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
author_id = article.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
updated_at = article.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
# Build metadata
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.get("label_names", []) if label],
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.get("content_tag_ids", [])
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# Remove empty values
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return new_author_mapping, Document(
id=f"article:{article['id']}",
return Document(
id=f"article:{article.id}",
sections=[
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
Section(link=article.html_url, text=parse_html_page_basic(article.body))
],
source=DocumentSource.ZENDESK,
semantic_identifier=article["title"],
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author] if author else None,
primary_owners=[author],
metadata=metadata,
)
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
) -> tuple[dict[str, BasicExpertInfo] | None, str]:
author_id = comment.get("author_id")
if not author_id:
author = None
else:
author = (
author_map.get(author_id)
if author_id in author_map
else _fetch_author(client, author_id)
)
new_author_mapping = {author_id: author} if author_id and author else None
comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}"
comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}"
return new_author_mapping, comment_text
def _ticket_to_document(
ticket: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
client: ZendeskClient,
default_subdomain: str,
) -> tuple[dict[str, BasicExpertInfo] | None, Document]:
submitter_id = ticket.get("submitter")
if not submitter_id:
submitter = None
else:
submitter = (
author_map.get(submitter_id)
if submitter_id in author_map
else _fetch_author(client, submitter_id)
)
new_author_mapping = (
{submitter_id: submitter} if submitter_id and submitter else None
)
updated_at = ticket.get("updated_at")
update_time = time_str_to_utc(updated_at) if updated_at else None
metadata: dict[str, str | list[str]] = {}
if status := ticket.get("status"):
metadata["status"] = status
if priority := ticket.get("priority"):
metadata["priority"] = priority
if tags := ticket.get("tags"):
metadata["tags"] = tags
if ticket_type := ticket.get("type"):
metadata["ticket_type"] = ticket_type
# Fetch comments for the ticket
comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {})
comments = comments_data.get("comments", [])
comment_texts = []
for comment in comments:
new_author_mapping, comment_text = _get_comment_text(
comment, author_map, client
)
if new_author_mapping:
author_map.update(new_author_mapping)
comment_texts.append(comment_text)
comments_text = "\n\n".join(comment_texts)
subject = ticket.get("subject")
full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}"
ticket_url = ticket.get("url")
subdomain = (
ticket_url.split("//")[1].split(".zendesk.com")[0]
if ticket_url
else default_subdomain
)
ticket_display_url = (
f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}"
)
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=[submitter] if submitter else None,
metadata=metadata,
)
class ZendeskClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Zendesk Client is not set up, was load_credentials called?")
class ZendeskConnector(LoadConnector, PollConnector):
@@ -264,10 +66,44 @@ class ZendeskConnector(LoadConnector, PollConnector):
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
self.content_type = content_type
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -276,23 +112,87 @@ class ZendeskConnector(LoadConnector, PollConnector):
.replace("https://", "")
.split(".zendesk.com")[0]
)
self.subdomain = subdomain
self.client = ZendeskClient(
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
self.zendesk_client = Zenpy(
subdomain=subdomain,
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _ticket_to_document(self, ticket: Ticket) -> Document:
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
owner = None
if ticket.requester and ticket.requester.name and ticket.requester.email:
owner = [
BasicExpertInfo(
display_name=ticket.requester.name, email=ticket.requester.email
)
]
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
metadata: dict[str, str | list[str]] = {}
if ticket.status is not None:
metadata["status"] = ticket.status
if ticket.priority is not None:
metadata["priority"] = ticket.priority
if ticket.tags:
metadata["tags"] = ticket.tags
if ticket.type is not None:
metadata["ticket_type"] = ticket.type
# Fetch comments for the ticket
comments = self.zendesk_client.tickets.comments(ticket=ticket)
# Combine all comments into a single text
comments_text = "\n\n".join(
[
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
for comment in comments
if comment.body
]
)
# Combine ticket description and comments
description = (
ticket.description
if hasattr(ticket, "description") and ticket.description
else ""
)
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
# Extract subdomain from ticket.url
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
# Build the html url for the ticket
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
return Document(
id=f"zendesk_ticket_{ticket.id}",
sections=[Section(link=ticket_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
doc_updated_at=update_time,
primary_owners=owner,
metadata=metadata,
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
if self.content_type == "articles":
yield from self._poll_articles(start)
@@ -304,30 +204,26 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = _get_articles(self.client, start_time=int(start) if start else None)
# This one is built on the fly as there may be more many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
articles = (
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
if start is None
else self.zendesk_client.help_center.articles.incremental( # type: ignore
start_time=int(start)
)
)
doc_batch = []
for article in articles:
if (
article.get("body") is None
or article.get("draft")
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.get("label_names", [])
for label in article.label_names
)
):
continue
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(_article_to_document(article, self.content_tags))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
@@ -338,14 +234,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
if self.zendesk_client is None:
raise ZendeskClientNotSetUpError()
author_map: dict[str, BasicExpertInfo] = {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
)
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
while True:
doc_batch = []
@@ -354,20 +246,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.get("status") == "deleted":
if ticket.status == "deleted":
continue
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
doc_batch.append(self._ticket_to_document(ticket))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@@ -385,6 +267,7 @@ class ZendeskConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()

View File

@@ -1,5 +1,3 @@
import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackBotConfig
@@ -50,16 +48,3 @@ def validate_channel_names(
)
return cleaned_channel_names
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -1,36 +1,18 @@
import asyncio
import os
import signal
import sys
import threading
import time
from threading import Event
from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -64,7 +46,6 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
@@ -72,26 +53,17 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants",
"Number of active tenants handled by this pod",
["namespace", "pod"],
)
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
# the cause.
@@ -105,234 +77,10 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# This is always (currently) the user id of Slack's official slackbot
# this is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
self.acquire_thread = threading.Thread(
target=self.acquire_tenants_loop, daemon=True
)
self.heartbeat_thread = threading.Thread(
target=self.heartbeat_loop, daemon=True
)
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
DanswerRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if self.socket_clients.get(tenant_id):
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if self.socket_clients.get(tenant_id):
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = (
f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
)
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
def start_socket_client(
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
if client:
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
return
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
logger.info("Shutdown complete")
sys.exit(0)
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -424,7 +172,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
)
return False
@@ -499,7 +247,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -521,14 +269,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.info(f"Rephrased message: {msg}")
logger.notice(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.info(f"Received Slack message: {msg}")
logger.notice(f"Received Slack message: {msg}")
if tagged:
logger.debug("User tagged DanswerBot")
@@ -729,21 +477,94 @@ def _get_socket_client(
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
set_is_ee_based_on_env_variable()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
try:
# Keep the main thread alive
while tenant_handler.running:
time.sleep(1)
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
except Exception:
logger.exception("Fatal error in main thread")
tenant_handler.shutdown(None, None)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)

View File

@@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
@@ -36,16 +35,12 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users_count(db_session: Session) -> int:
def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = (
db_session.query(User)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.count()
)
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users

View File

@@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_call: ToolCall | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -749,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -282,32 +282,3 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()
def mark_cc_pair_as_permissions_synced(
db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.last_time_perm_sync = start_time
db_session.commit()
def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
# The sync time can be marked after it ran because all group syncs
# are run in full, not polling for changes.
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()

View File

@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()
@@ -76,10 +76,8 @@ def _add_user_filters(
.where(~UG__CCpair.user_group_id.in_(user_groups))
.correlate(ConnectorCredentialPair)
)
where_clause |= ConnectorCredentialPair.creator_id == user.id
else:
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
return stmt.where(where_clause)
@@ -353,11 +351,7 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
@@ -389,7 +383,6 @@ def add_credential_to_connector(
)
association = ConnectorCredentialPair(
creator_id=user.id if user else None,
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,
@@ -445,10 +438,7 @@ def remove_credential_from_connector(
)
if association is not None:
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)

View File

@@ -10,7 +10,10 @@ from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
@@ -421,15 +424,25 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
db_session.commit()
def delete_service_account_credentials(
user: User | None, db_session: Session, source: DocumentSource
def delete_gmail_service_account_credentials(
user: User | None, db_session: Session
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
and credential.source == source
if credential.credential_json.get(
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
):
db_session.delete(credential)
db_session.commit()
def delete_google_drive_service_account_credentials(
user: User | None, db_session: Session
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
for credential in credentials:
if credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY):
db_session.delete(credential)
db_session.commit()

View File

@@ -19,7 +19,6 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -47,21 +46,13 @@ def count_documents_by_needs_sync(session: Session) -> int:
"""Get the count of all documents where:
1. last_modified is newer than last_synced
2. last_synced is null (meaning we've never synced)
AND the document has a relationship with a connector/credential pair
TODO: The documents without a relationship with a connector/credential pair
should be cleaned up somehow eventually.
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count(DbDocument.id.distinct()))
session.query(func.count())
.select_from(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.filter(
or_(
DbDocument.last_modified > DbDocument.last_synced,
@@ -100,22 +91,6 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
return list(db_session.scalars(stmt).all())
def construct_document_select_for_connector_credential_pair(
connector_id: int, credential_id: int | None = None
) -> Select:
@@ -129,21 +104,6 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_documents_for_cc_pair(
db_session: Session,
cc_pair_id: int,
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
stmt = construct_document_select_for_connector_credential_pair(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
return list(db_session.scalars(stmt).all())
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
@@ -308,7 +268,7 @@ def get_access_info_for_documents(
return db_session.execute(stmt).all() # type: ignore
def _upsert_documents(
def upsert_documents(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
initial_boost: int = DEFAULT_BOOST,
@@ -346,8 +306,6 @@ def _upsert_documents(
]
)
# This does not update the permissions of the document if
# the document already exists.
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
@@ -364,7 +322,7 @@ def _upsert_documents(
db_session.commit()
def _upsert_document_by_connector_credential_pair(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
@@ -446,8 +404,8 @@ def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
_upsert_documents(db_session, document_metadata_batch)
_upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
@@ -505,6 +463,7 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@@ -29,7 +29,6 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
@@ -38,10 +37,10 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -189,13 +188,6 @@ class SqlEngine:
return ""
return cls._app_name
@classmethod
def reset_engine(cls) -> None:
with cls._lock:
if cls._engine:
cls._engine.dispose()
cls._engine = None
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
@@ -317,14 +309,8 @@ async def get_async_session_with_tenant(
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
@@ -332,79 +318,47 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
yield session
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def set_search_path_on_checkout(

View File

@@ -219,7 +219,7 @@ def mark_attempt_partially_succeeded(
def mark_attempt_failed(
index_attempt_id: int,
index_attempt: IndexAttempt,
db_session: Session,
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
@@ -227,7 +227,7 @@ def mark_attempt_failed(
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()

View File

@@ -126,8 +126,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
)
visible_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
@@ -135,9 +135,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
@@ -173,11 +170,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="creator",
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
class InputPrompt(Base):
@@ -425,9 +417,6 @@ class ConnectorCredentialPair(Base):
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
@@ -460,14 +449,6 @@ class ConnectorCredentialPair(Base):
"IndexAttempt", back_populates="connector_credential_pair"
)
# the user id of the user that created this cc pair
creator_id: Mapped[UUID | None] = mapped_column(nullable=True)
creator: Mapped["User"] = relationship(
"User",
back_populates="cc_pairs",
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
class Document(Base):
__tablename__ = "document"
@@ -753,10 +734,9 @@ class IndexAttempt(Base):
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
# Nullable because in the past, we didn't allow swapping out embedding models live
search_settings_id: Mapped[int] = mapped_column(
ForeignKey("search_settings.id", ondelete="SET NULL"),
nullable=True,
ForeignKey("search_settings.id"),
nullable=False,
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -776,7 +756,7 @@ class IndexAttempt(Base):
"ConnectorCredentialPair", back_populates="index_attempts"
)
search_settings: Mapped[SearchSettings | None] = relationship(
search_settings: Mapped[SearchSettings] = relationship(
"SearchSettings", back_populates="index_attempts"
)
@@ -937,15 +917,10 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=False
)
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage",
back_populates="tool_call",
uselist=False,
"ChatMessage", back_populates="tool_calls"
)
@@ -1076,13 +1051,12 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
tool_call: Mapped["ToolCall"] = relationship(
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
@@ -1340,6 +1314,7 @@ class StarterMessage(TypedDict):
in Postgres"""
name: str
description: str
message: str

View File

@@ -743,4 +743,5 @@ def delete_persona_by_name(
)
db_session.execute(stmt)
db_session.commit()

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import get_session_with_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_default_tenant() as db_session:
with get_session_with_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -14,6 +14,7 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -22,14 +23,7 @@ logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
new index is done building. If so, swap the indices and expire the old one."""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
@@ -49,9 +43,9 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
now_old_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
search_settings=now_old_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
@@ -73,6 +67,6 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
return old_search_settings
if MULTI_TENANT:
return now_old_search_settings
return None

View File

@@ -1,111 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -24,13 +24,6 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -44,7 +37,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.model_dump() for header in custom_headers]
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

Some files were not shown because too many files have changed in this diff Show More